UNPKG

@tensorflow-models/coco-ssd

Version:

Object detection model (coco-ssd) in TensorFlow.js

116 lines (102 loc) 3.97 kB
/**
 * @license
 * Copyright 2018 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */

import * as tfc from '@tensorflow/tfjs-core';

import {NamedTensorsMap} from '../../data/types';
import {ExecutionContext} from '../../executor/execution_context';
import {Node, ValueType} from '../types';

/**
 * Resolves the value of parameter `paramName` for `node`.
 *
 * When the param has an `inputIndex`, its value is taken from the node's
 * inputs:
 *  - type `'tensor'`: the single tensor named at `inputIndex`;
 *  - type `'tensors'`: the tensors for a range of input names (see below);
 *  - any other type: the data of the input tensor at `inputIndex`, downloaded
 *    via `dataSync()`; type `'number'` returns only its first element.
 * Otherwise it falls back to `param.value` (or `undefined` if the param does
 * not exist).
 *
 * @param paramName Name of the entry in `node.params` to resolve.
 * @param node Graph node whose params and input names are read.
 * @param tensorMap Tensors map keyed by (context-qualified) node name.
 * @param context Execution context used to locate input tensors.
 * @returns The resolved value.
 */
export function getParamValue(
    paramName: string, node: Node, tensorMap: NamedTensorsMap,
    context: ExecutionContext): ValueType {
  const param = node.params[paramName];
  if (param && param.inputIndex !== undefined) {
    if (param.type === 'tensor') {
      return getTensor(node.inputNames[param.inputIndex], tensorMap, context);
    }
    if (param.type === 'tensors') {
      // Always `slice`, never `splice`: `splice` removes the names from
      // `node.inputNames` in place, corrupting the node for later executions
      // and for any other code that reads its inputs.
      // NOTE(review): when inputIndex is 0 and inputParamLength is undefined,
      // `-undefined` is NaN and the slice is empty — confirm the op mapper
      // always sets inputParamLength.
      const inputs = param.inputIndex === 0 ?
          (param.inputParamLength === 0 ?
               node.inputNames :
               node.inputNames.slice(
                   param.inputIndex, -param.inputParamLength)) :
          node.inputNames.slice(param.inputIndex);
      return inputs.map(name => getTensor(name, tensorMap, context));
    }
    // `slice(i)[0]` rather than `[i]` so a negative inputIndex counts from
    // the end of the input list.
    const data = Array.prototype.slice.call(
        getTensor(
            node.inputNames.slice(param.inputIndex)[0], tensorMap, context)
            .dataSync());
    return param.type === 'number' ? data[0] : data;
  }
  return param && param.value;
}

/**
 * Retrieve the tensor based on input name by extracting the node name and
 * output index information.
 *
 * The node name is looked up under each id of `context.currentContextIds` in
 * order; the first id whose qualified name is present in `tensorsMap` wins.
 *
 * @param name Node input name, optionally suffixed with `:<outputIndex>`.
 * @param tensorsMap Tensors map keyed by the (context-qualified) node name.
 * @param context Execution context supplying the candidate context ids.
 * @returns The tensor at the parsed output index, or `undefined` when no
 *     current context has an entry for the node.
 */
export function getTensor(
    name: string, tensorsMap: NamedTensorsMap,
    context: ExecutionContext): tfc.Tensor {
  const [nodeName, index] = parseNodeName(name);
  const contextId = context.currentContextIds.find(
      id => !!tensorsMap[getNodeNameWithContextId(nodeName, id)]);
  return contextId !== undefined ?
      tensorsMap[getNodeNameWithContextId(nodeName, contextId)][index] :
      undefined;
}

/**
 * Retrieve the tensors based on input name for current context.
 *
 * Note: the exported name contains a typo ("Contenxt"); it is kept as-is for
 * compatibility with existing callers.
 *
 * @param name Node name (without an output index suffix).
 * @param tensorsMap Tensors map keyed by the (context-qualified) node name.
 * @param context Execution context whose `currentContextId` qualifies `name`.
 * @returns The tensor list stored for the node in the current context.
 */
export function getTensorsForCurrentContenxt(
    name: string, tensorsMap: NamedTensorsMap,
    context: ExecutionContext): tfc.Tensor[] {
  return tensorsMap[getNodeNameWithContextId(name, context.currentContextId)];
}

/**
 * Returns the node name and index from the Node input name.
 *
 * @param inputName The input name of the node, in the format
 *     node_name:output_index, e.g. MatMul:0. If the output_index is not set,
 *     it defaults to 0.
 * @param context Optional execution context; when it has a non-empty
 *     `currentContextId`, the returned node name is suffixed with
 *     `-<contextId>`.
 * @returns A `[nodeName, outputIndex]` tuple.
 */
export function getNodeNameAndIndex(
    inputName: string, context?: ExecutionContext): [string, number] {
  const [nodeName, index] = parseNodeName(inputName);
  return [
    getNodeNameWithContextId(nodeName, context && context.currentContextId),
    index
  ];
}

/**
 * Qualifies `name` with `contextId` as `<name>-<contextId>`; a missing or
 * empty context id leaves the name unchanged.
 */
function getNodeNameWithContextId(name: string, contextId?: string): string {
  return !!contextId ? `${name}-${contextId}` : name;
}

/**
 * Splits an input name of the form `node_name:output_index` at its last `:`.
 * Names without a `:` map to output index 0. The suffix is converted with
 * `Number`, so a non-numeric suffix yields NaN.
 */
export function parseNodeName(name: string): [string, number] {
  const index = name.lastIndexOf(':');
  if (index === -1) {
    return [name, 0];
  }
  const nodeName = name.substring(0, index);
  return [nodeName, Number(name.substring(index + 1))];
}

/**
 * Splits `arr` into consecutive chunks of `size` elements; the last chunk
 * holds the remainder and may be shorter. `arr` itself is not modified.
 *
 * @throws Error if `size` is not a positive number (a non-positive step would
 *     otherwise never advance and loop forever).
 */
export function split<T>(arr: readonly T[], size: number): T[][] {
  if (!(size > 0)) {
    throw new Error(`split: size must be a positive number, got ${size}`);
  }
  const chunks: T[][] = [];
  for (let i = 0; i < arr.length; i += size) {
    chunks.push(arr.slice(i, i + size));
  }
  return chunks;
}