UNPKG

@tensorflow-models/coco-ssd

Version:

Object detection model (coco-ssd) in TensorFlow.js

485 lines (443 loc) 18.1 kB
/**
 * @license
 * Copyright 2018 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */

import {DataType, Tensor, tidy, util} from '@tensorflow/tfjs-core';

// tslint:disable-next-line:max-line-length
import {NamedTensorMap, NamedTensorsMap, TensorArrayMap, TensorInfo} from '../data/types';
// tslint:disable-next-line:max-line-length
import {getNodeNameAndIndex, getParamValue, getTensor, getTensorsForCurrentContenxt, parseNodeName} from '../operations/executors/utils';
import {executeOp} from '../operations/operation_executor';
import {Graph, Node} from '../operations/types';

import {ExecutionContext, ExecutionContextInfo} from './execution_context';

/** A node queued for execution together with the context it must run in. */
interface NodeWithContexts {
  contexts: ExecutionContextInfo[];
  node: Node;
}

/**
 * Executes a converted TensorFlow graph.
 *
 * Graphs without control flow or dynamic shape ops are compiled once per set
 * of input nodes into a fixed execution order and run synchronously through
 * `execute`. All other graphs must be run through `executeAsync`, which
 * discovers the execution order at runtime.
 */
export class GraphExecutor {
  /** Cache of compiled execution orders, keyed by the sorted input names. */
  private compiledMap: Map<string, Node[]> = new Map();
  private _weightMap: NamedTensorsMap = {};
  // Initialised eagerly so that executeAsync() can be called before any
  // weights have been assigned without dereferencing `undefined`.
  private weightIds: number[] = [];
  private placeholders: Node[];
  private _outputs: Node[];
  private SEPERATOR = ',';

  get weightMap(): NamedTensorsMap {
    return this._weightMap;
  }

  /** Stores the weights and records their tensor ids for later lookups. */
  set weightMap(weightMap: NamedTensorsMap) {
    const weightIds = Object.keys(weightMap).map(
        key => weightMap[key].map(tensor => tensor.id));
    this.weightIds = [].concat.apply([], weightIds);
    this._weightMap = weightMap;
  }

  /** Name, shape and dtype (when declared) of every placeholder node. */
  get inputs(): TensorInfo[] {
    return this.placeholders.map(node => {
      return {
        name: node.name,
        shape: node.params['shape'] ?
            node.params['shape'].value as number[] :
            undefined,
        dtype: node.params['dtype'] ?
            node.params['dtype'].value as DataType :
            undefined
      };
    });
  }

  /** Name, shape and dtype (when declared) of every default output node. */
  get outputs(): TensorInfo[] {
    return this._outputs.map(node => {
      return {
        name: node.name,
        shape: node.params['shape'] ?
            node.params['shape'].value as number[] :
            undefined,
        dtype: node.params['dtype'] ?
            node.params['dtype'].value as DataType :
            undefined
      };
    });
  }

  get inputNodes(): string[] {
    return this.placeholders.map(node => node.name);
  }

  get outputNodes(): string[] {
    // Read the names straight from the nodes instead of building the full
    // TensorInfo objects first.
    return this._outputs.map(node => node.name);
  }

  constructor(private graph: Graph) {
    this.placeholders = graph.placeholders;
    this._outputs = graph.outputs;
    this.compile();
  }

  get isControlFlowModel(): boolean {
    return this.graph.withControlFlow;
  }

  get isDynamicShapeModel(): boolean {
    return this.graph.withDynamicShape;
  }

  /**
   * Compiles the inference graph to generate the topology order of op nodes,
   * cache the result for inference execution.
   *
   * @param startNodes Nodes to start the traversal from; defaults to the
   *     graph placeholders.
   */
  private compile(startNodes?: Node[]) {
    // Do not compile for graph with control flow, since the execution order
    // requires runtime evaluation of the output tensors.
    if (this.graph.withControlFlow || this.graph.withDynamicShape) {
      return;
    }
    const compiledOrder: Node[] = [];
    const inputs = startNodes || this.graph.placeholders;
    const sortedNodeNames = inputs.map(node => node.name).sort();
    const nameKey = sortedNodeNames.join(this.SEPERATOR);

    // Do nothing if the compiled graph cache already contains the input.
    if (this.compiledMap.get(nameKey)) {
      return;
    }

    const stack = [...inputs, ...this.graph.weights];
    const visited: {[key: string]: boolean} = {};
    while (stack.length > 0) {
      const node = stack.pop();
      visited[node.name] = true;
      compiledOrder.push(node);
      node.children.forEach((childNode) => {
        // A child becomes runnable once every one of its inputs was visited.
        if (!visited[childNode.name] && childNode.inputNames.every(name => {
              const [nodeName, ] = getNodeNameAndIndex(name);
              return visited[nodeName];
            })) {
          stack.push(childNode);
        }
      });
    }
    this.compiledMap.set(nameKey, compiledOrder);
  }

  /**
   * Executes the inference for given input tensors.
   * @param inputs Tensor map for the model inputs, keyed by the input node
   * names.
   * @param strictInputCheck When true, missing or unused input keys are
   * reported as errors.
   * @param outputs output node name from the Tensorflow model, if no outputs
   * are specified, the default outputs of the model would be used. You can
   * inspect intermediate nodes of the model by adding them to the outputs
   * array.
   * @throws Error if the inputs fail validation, if the graph could not be
   * compiled (control flow or dynamic shape ops), or if a requested output is
   * not produced by the compiled execution order.
   */
  execute(
      inputs: NamedTensorsMap, strictInputCheck = true,
      outputs?: string|string[]): NamedTensorMap {
    const names = Object.keys(inputs).sort();
    this.checkInput(inputs, strictInputCheck);
    this.checkInputShapeAndType(inputs, strictInputCheck);

    this.compile(names.map(name => this.graph.nodes[name]));
    const outputNames = this.calculateOutputs(outputs);
    const compiledNodes = this.compiledMap.get(names.join(this.SEPERATOR));
    // compile() is a no-op for control flow / dynamic shape graphs, so there
    // is no execution order to follow; fail with an actionable message
    // instead of a TypeError on `undefined`.
    if (compiledNodes == null) {
      throw new Error(
          `The graph contains control flow or dynamic shape ops and cannot ` +
          `be executed synchronously, please use executeAsync instead.`);
    }
    this.checkOutput(compiledNodes, outputNames);

    // Requested outputs may carry an output index (e.g. 'node:0') while the
    // tensor map is keyed by node name only.
    const outputNodeNames = outputNames.map(name => parseNodeName(name)[0]);

    const tensorArrayMap: TensorArrayMap = {};
    const result = tidy(() => {
      const context = new ExecutionContext(this._weightMap, tensorArrayMap);
      const tensorMap = {...this.weightMap, ...inputs};
      const tensorsToKeep = this.getFrozenTensorIds(tensorMap);
      const intermediateTensorConsumerCount: {[key: number]: number} = {};

      for (let i = 0; i < compiledNodes.length; i++) {
        const node = compiledNodes[i];
        if (!tensorMap[node.name]) {
          tensorMap[node.name] =
              executeOp(node, tensorMap, context) as Tensor[];
          this.checkTensorForDisposal(
              node.name, node, tensorMap, context, tensorsToKeep, outputNames,
              intermediateTensorConsumerCount);
        }
        // stop the execution if all outputs are found.
        if (outputNodeNames.every(name => !!tensorMap[name])) {
          break;
        }
      }
      return this.findOutputs(tensorMap, context, outputNames);
    });
    return result;
  }

  /** Collects the ids of every tensor currently present in `tensorMap`. */
  private getFrozenTensorIds(tensorMap: NamedTensorsMap): Set<number> {
    const ids = [].concat.apply(
        [],
        Object.keys(tensorMap)
            .map(key => tensorMap[key])
            .map(tensors => tensors.map(tensor => tensor.id)));
    return new Set(ids);
  }

  /**
   * Reference-counts the tensors produced by `node` and disposes the input
   * tensors of `node` whose last consumer has just run.
   *
   * @param nodeName Key under which the node's outputs are stored in
   *     `tensorMap`.
   * @param node The node that has just been executed.
   * @param tensorMap All tensors computed so far, keyed by node name.
   * @param context The execution context used to resolve input tensors.
   * @param tensorsToKeep Ids of tensors that must never be disposed here.
   * @param outputNames Requested output names, with or without an output
   *     index suffix.
   * @param intermediateTensorConsumerCount Remaining consumer count per
   *     intermediate tensor id; updated in place.
   */
  private checkTensorForDisposal(
      nodeName: string, node: Node, tensorMap: NamedTensorsMap,
      context: ExecutionContext, tensorsToKeep: Set<number>,
      outputNames: string[],
      intermediateTensorConsumerCount: {[key: number]: number}) {
    // Skip output nodes and any control flow nodes, since its dependency is
    // tricky to track correctly. Output names are reduced to their node name
    // so that a request such as 'node:0' still protects the tensors of 'node'.
    if (node.category === 'control' ||
        outputNames.some(name => parseNodeName(name)[0] === nodeName)) {
      return;
    }

    // Each produced tensor starts with one pending consumer per child node.
    tensorMap[nodeName].forEach(tensor => {
      if (tensor != null) {
        intermediateTensorConsumerCount[tensor.id] =
            (intermediateTensorConsumerCount[tensor.id] || 0) +
            node.children.length;
      }
    });

    node.inputs.forEach(input => {
      // Skip any control flow nodes, since its dependency is tricky to track
      // correctly.
      if (input.category === 'control') {
        return;
      }
      const tensors =
          getTensorsForCurrentContenxt(input.name, tensorMap, context);
      if (tensors == null) {
        return;
      }
      tensors.forEach(tensor => {
        if (tensor && !tensorsToKeep.has(tensor.id)) {
          const count = intermediateTensorConsumerCount[tensor.id];
          if (count === 1) {
            // This node was the last pending consumer.
            tensor.dispose();
            delete intermediateTensorConsumerCount[tensor.id];
          } else if (count != null) {
            // only intermediate nodes has count set, inputs and weights are
            // not.
            intermediateTensorConsumerCount[tensor.id]--;
          }
        }
      });
    });
  }

  /**
   * Executes the inference for given input tensors in Async fashion.
   * @param inputs Tensor map for the model inputs, keyed by the input node
   * names.
   * @param outputs output node name from the Tensorflow model, if no outputs
   * are specified, the default outputs of the model would be used. You can
   * inspect intermediate nodes of the model by adding them to the outputs
   * array.
   */
  async executeAsync(inputs: NamedTensorsMap, outputs?: string|string[]):
      Promise<NamedTensorMap> {
    this.checkInput(inputs, false);
    this.checkInputShapeAndType(inputs, false);
    const tensorArrayMap: TensorArrayMap = {};
    const context = new ExecutionContext(this._weightMap, tensorArrayMap);
    const outputNames = this.calculateOutputs(outputs);
    // Graph with control flow op requires runtime evaluation of the execution
    // order, while without control flow the execution order is pre-determined
    // in the compile method.
    const tensors =
        await this.executeWithControlFlow(inputs, context, outputNames);
    const results = this.findOutputs(tensors, context, outputNames);

    // Dispose all the intermediate tensors, i.e. everything that is neither
    // a weight, a caller provided input nor a returned output. The ids are
    // gathered in a Set so each tensor costs a single lookup.
    const idsToKeep = new Set<number>(this.weightIds);
    Object.keys(results).forEach(key => idsToKeep.add(results[key].id));
    Object.keys(inputs).forEach(
        key => inputs[key].forEach(input => idsToKeep.add(input.id)));

    Object.keys(tensors).forEach(key => {
      tensors[key].forEach(tensor => {
        if (tensor && !idsToKeep.has(tensor.id)) {
          tensor.dispose();
        }
      });
    });
    return results;
  }

  /**
   * When there are control flow nodes in the graph, the graph execution use
   * ExecutionContext to keep track of the frames and loop iterators.
   * @param inputs placeholder tensors for the graph.
   * @param context the execution context object for current execution.
   * @param outputNames names of the requested outputs; the tensors of these
   * nodes are not disposed during execution.
   */
  private async executeWithControlFlow(
      inputs: NamedTensorsMap, context: ExecutionContext,
      outputNames: string[]): Promise<NamedTensorsMap> {
    const names = Object.keys(inputs);
    const inputNodes = names.map(name => this.graph.nodes[name]);
    const stack: NodeWithContexts[] =
        [...inputNodes, ...this.graph.weights].map(node => {
          return {node, contexts: context.currentContext};
        });
    const tensorMap = {...this.weightMap, ...inputs};
    const intermediateTensorConsumerCount: {[key: number]: number} = {};
    const tensorsToKeep = this.getFrozenTensorIds(tensorMap);
    const added: {[key: string]: boolean} = {};
    // Resolved async ops may push new nodes onto the stack, so keep draining
    // it until nothing is left.
    while (stack.length > 0) {
      const promises = this.processStack(
          inputNodes, stack, context, tensorMap, added, tensorsToKeep,
          outputNames, intermediateTensorConsumerCount);
      await Promise.all(promises);
    }
    return tensorMap;
  }

  /**
   * Drains `stack`, executing every node that was not supplied as an input
   * and queueing the children that became runnable.
   *
   * @returns The promises of the ops that produced their result
   *     asynchronously; their children are queued once they resolve.
   */
  private processStack(
      inputNodes: Node[], stack: NodeWithContexts[], context: ExecutionContext,
      tensorMap: NamedTensorsMap, added: {[key: string]: boolean},
      tensorsToKeep: Set<number>, outputNames: string[],
      intermediateTensorConsumerCount: {[key: number]: number}) {
    const promises: Array<Promise<Tensor[]>> = [];
    while (stack.length > 0) {
      const item = stack.pop();
      context.currentContext = item.contexts;
      let nodeName = '';
      // The tensor of the Enter op with isConstant set should be set
      // in the parent scope, so it will be available as constant for the
      // whole loop.
      if (item.node.op === 'enter' &&
          getParamValue('isConstant', item.node, tensorMap, context)) {
        [nodeName] = getNodeNameAndIndex(item.node.name, context);
      }

      // only process nodes that are not provided as input nodes.
      if (inputNodes.indexOf(item.node) === -1) {
        const tensors = executeOp(item.node, tensorMap, context);
        if (!nodeName) {
          [nodeName] = getNodeNameAndIndex(item.node.name, context);
        }
        // Remember the context so it can be restored when an async op
        // resolves after the executor has moved on.
        const currentContext = context.currentContext;
        if (tensors instanceof Promise) {
          promises.push(tensors.then(t => {
            tensorMap[nodeName] = t;
            context.currentContext = currentContext;
            this.checkTensorForDisposal(
                nodeName, item.node, tensorMap, context, tensorsToKeep,
                outputNames, intermediateTensorConsumerCount);
            this.processChildNodes(item.node, stack, context, tensorMap, added);
            return t;
          }));
        } else {
          tensorMap[nodeName] = tensors;
          this.checkTensorForDisposal(
              nodeName, item.node, tensorMap, context, tensorsToKeep,
              outputNames, intermediateTensorConsumerCount);
          this.processChildNodes(item.node, stack, context, tensorMap, added);
        }
      } else {
        this.processChildNodes(item.node, stack, context, tensorMap, added);
      }
    }
    return promises;
  }

  /**
   * Pushes onto `stack` every not yet queued child of `node` whose inputs
   * are available, marking it in `added` so it is queued only once.
   */
  private processChildNodes(
      node: Node, stack: NodeWithContexts[], context: ExecutionContext,
      tensorMap: NamedTensorsMap, added: {[key: string]: boolean}) {
    node.children.forEach((childNode) => {
      const [nodeName, ] = getNodeNameAndIndex(childNode.name, context);
      if (added[nodeName]) {
        return;
      }
      const hasValue = (name: string) => !!getTensor(name, tensorMap, context);
      // Merge op can be pushed if any of its inputs has value, otherwise all
      // inputs must have value.
      const ready = childNode.op === 'merge' ?
          childNode.inputNames.some(hasValue) :
          childNode.inputNames.every(hasValue);
      if (ready) {
        added[nodeName] = true;
        stack.push({contexts: context.currentContext, node: childNode});
      }
    });
  }

  /**
   * Normalises the requested outputs to an array of names, falling back to
   * the default output nodes of the graph when none are given.
   */
  private calculateOutputs(outputs?: string|string[]): string[] {
    if (outputs && !(outputs instanceof Array)) {
      outputs = [outputs];
    }
    return (outputs || this.graph.outputs.map(node => node.name)) as string[];
  }

  /** Looks up the tensor of every requested output, keyed by output name. */
  private findOutputs(
      tensorMap: NamedTensorsMap, context: ExecutionContext,
      outputs?: string|string[]): NamedTensorMap {
    const requestedOutputs = this.calculateOutputs(outputs);
    return requestedOutputs.reduce<NamedTensorMap>((map, name) => {
      map[name] = getTensor(name, tensorMap, context);
      return map;
    }, {});
  }

  /**
   * Releases the memory used by the weight tensors.
   */
  dispose() {
    Object.keys(this.weightMap)
        .forEach(
            key => this.weightMap[key].forEach(tensor => tensor.dispose()));
  }

  /**
   * Asserts that the first tensor provided for each placeholder matches the
   * shape and dtype declared on the placeholder node, when declared. A
   * declared dimension of -1 matches any size.
   */
  private checkInputShapeAndType(
      inputs: NamedTensorsMap, strictInputCheck = true) {
    this.placeholders.forEach(node => {
      const inputTensors = inputs[node.name];
      // do nothing if not strict input check and input tensors is not for
      // the placeholders.
      if (!strictInputCheck && !inputTensors) {
        return;
      }

      const input = inputTensors[0];
      if (node.params['shape'] && node.params['shape'].value) {
        const shape = node.params['shape'].value as number[];
        const match = shape.length === input.shape.length &&
            input.shape.every(
                (dim, index) => shape[index] === -1 || shape[index] === dim);
        util.assert(
            match,
            `The shape of dict['${
                node.name}'] provided in model.execute(dict) must be [${
                shape}], but was [${input.shape}]`);
      }
      if (node.params['dtype'] && node.params['dtype'].value) {
        util.assert(
            input.dtype === node.params['dtype'].value as string,
            `The dtype of dict['${
                node.name}'] provided in model.execute(dict) must be ${
                node.params['dtype'].value}, but was ${input.dtype}`);
      }
    });
  }

  /**
   * Validates the keys of `inputs` against the graph.
   *
   * @throws Error when, under strict checking, a placeholder is missing or an
   * extra key is present; and always when a key is not a node of the graph.
   */
  private checkInput(inputs: NamedTensorsMap, strictInputCheck = true) {
    const inputKeys = Object.keys(inputs);
    // Evaluate the getter once instead of once per key.
    const inputNodes = this.inputNodes;

    const missing = inputNodes.filter(name => inputKeys.indexOf(name) === -1);
    const extra = inputKeys.filter(name => inputNodes.indexOf(name) === -1);
    const notInGraph = extra.filter(name => !this.graph.nodes[name]);

    if (missing.length > 0 && strictInputCheck) {
      throw new Error(
          `The dict provided in model.execute(dict) has the keys ` +
          `[${inputKeys}], but is missing the required keys: [${missing}].`);
    }

    if (extra.length > 0 && strictInputCheck) {
      throw new Error(
          `The dict provided in model.execute(dict) has ` +
          `unused keys: [${extra}]. Please provide only the following keys: ` +
          `[${inputNodes}].`);
    }

    if (notInGraph.length > 0) {
      throw new Error(
          `The dict provided in model.execute(dict) has ` +
          `keys: [${notInGraph}] not part of model graph.`);
    }
  }

  /**
   * Verifies that every requested output is produced by one of the compiled
   * nodes.
   *
   * @throws Error listing the output nodes that are not part of the compiled
   * execution order.
   */
  private checkOutput(compiledNodes: Node[], outputs: string[]) {
    const compiledNodeNames = compiledNodes.map(node => node.name);
    const extra: string[] = [];
    outputs.forEach(name => {
      const [nodeName] = parseNodeName(name);
      if (compiledNodeNames.indexOf(nodeName) === -1) extra.push(nodeName);
    });

    if (extra.length > 0) {
      throw new Error(
          `The following outputs are not generated by the execution: ` +
          `[${extra}].`);
    }
  }
}