@tensorflow-models/coco-ssd
Version:
Object detection model (coco-ssd) in TensorFlow.js
288 lines (267 loc) • 10.3 kB
text/typescript
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import {DataType} from '@tensorflow/tfjs-core';
import {tensorflow} from '../data/compiled_api';
import {getNodeNameAndIndex} from './executors/utils';
import * as arithmetic from './op_list/arithmetic';
import * as basicMath from './op_list/basic_math';
import * as control from './op_list/control';
import * as convolution from './op_list/convolution';
import * as creation from './op_list/creation';
import * as dynamic from './op_list/dynamic';
import * as evaluation from './op_list/evaluation';
import * as graph from './op_list/graph';
import * as image from './op_list/image';
import * as logical from './op_list/logical';
import * as matrices from './op_list/matrices';
import * as normalization from './op_list/normalization';
import * as reduction from './op_list/reduction';
import * as sliceJoin from './op_list/slice_join';
import * as spectral from './op_list/spectral';
import * as transformation from './op_list/transformation';
import {Graph, Node, OpMapper, ParamValue} from './types';
/**
 * TensorFlow op names that mark a graph as containing control flow.
 * `transformGraph` sets `withControlFlow` when any node uses one of them.
 */
const CONTROL_FLOW_OPS: string[] =
    ['Switch', 'Merge', 'Enter', 'Exit', 'NextIteration'];
/**
 * TensorFlow op names that mark a graph as having dynamic shapes.
 * `transformGraph` sets `withDynamicShape` when any node uses one of them.
 */
const DYNAMIC_SHAPE_OPS: string[] = [
  'NonMaxSuppressionV2',
  'NonMaxSuppressionV3',
  'Where',
];
/**
 * Converts TensorFlow `GraphDef` protos into the local `Graph`
 * representation. Op mappings are read from the `json` tables exported by the
 * `op_list/*` modules.
 *
 * The constructor is private; use the shared `OperationMapper.Instance`.
 */
export class OperationMapper {
  private static _instance: OperationMapper;

  /** Op mappers keyed by TensorFlow op name (`tfOpName`). */
  private opMappers: {[key: string]: OpMapper};

  /** Returns the singleton mapper, creating it on first access. */
  public static get Instance() {
    return this._instance || (this._instance = new this());
  }

  // Loads the op mappings from the JSON tables of every op_list module.
  private constructor() {
    const ops = [
      arithmetic, basicMath, control, convolution, creation, dynamic,
      evaluation, logical, image, graph, matrices, normalization, reduction,
      sliceJoin, spectral, transformation
    ];
    const mappersJson: OpMapper[] = [].concat.apply([], ops.map(op => op.json));
    this.opMappers = mappersJson.reduce<{[key: string]: OpMapper}>(
        (map, mapper: OpMapper) => {
          // If two tables declare the same tfOpName, the later one wins.
          map[mapper.tfOpName] = mapper;
          return map;
        },
        {});
  }

  /** True when the node's op is one of CONTROL_FLOW_OPS. */
  private isControlFlow(node: tensorflow.INodeDef) {
    return CONTROL_FLOW_OPS.some(op => op === node.op);
  }

  /** True when the node's op is one of DYNAMIC_SHAPE_OPS. */
  private isDynamicShape(node: tensorflow.INodeDef) {
    return DYNAMIC_SHAPE_OPS.some(op => op === node.op);
  }

  /**
   * Converts a TensorFlow GraphDef into the local `Graph` representation.
   *
   * Every node is converted with `mapNode`, then edges are wired in both
   * directions (`inputs` on the consumer, `children` on the producer). Nodes
   * with no inputs are reported as graph `inputs`; nodes with no children as
   * graph `outputs`. `Placeholder` nodes are collected into `placeholders` and
   * `Const` nodes into `weights`.
   *
   * @param graph The GraphDef to convert. A missing `node` list is treated as
   *     an empty graph.
   * @returns The converted graph plus `withControlFlow` / `withDynamicShape`
   *     flags.
   * @throws Error if a node uses an unsupported op, or if a node references an
   *     input node that is not part of the graph.
   */
  transformGraph(graph: tensorflow.IGraphDef): Graph {
    const tfNodes: tensorflow.INodeDef[] = graph.node || [];
    let withControlFlow = false;
    let withDynamicShape = false;
    const placeholders: Node[] = [];
    const weights: Node[] = [];
    const nodes = tfNodes.reduce<{[key: string]: Node}>((map, node) => {
      map[node.name] = this.mapNode(node);
      if (this.isControlFlow(node)) withControlFlow = true;
      if (this.isDynamicShape(node)) withDynamicShape = true;
      if (node.op === 'Placeholder') placeholders.push(map[node.name]);
      if (node.op === 'Const') weights.push(map[node.name]);
      return map;
    }, {});

    const inputs: Node[] = [];
    const outputs: Node[] = [];
    Object.keys(nodes).forEach(key => {
      const node = nodes[key];
      node.inputNames.forEach(name => {
        const [nodeName, ] = getNodeNameAndIndex(name);
        const inputNode = nodes[nodeName];
        // Report an inconsistent GraphDef explicitly instead of failing with a
        // TypeError on `undefined.children`.
        if (inputNode === undefined) {
          throw new Error(
              `Node '${node.name}' references unknown input '${name}'`);
        }
        node.inputs.push(inputNode);
        inputNode.children.push(node);
      });
      if (node.inputs.length === 0) inputs.push(node);
    });

    Object.keys(nodes).forEach(key => {
      const node = nodes[key];
      if (node.children.length === 0) outputs.push(node);
    });

    return {
      nodes,
      inputs,
      outputs,
      weights,
      placeholders,
      withControlFlow,
      withDynamicShape
    };
  }

  /**
   * Converts a single TensorFlow NodeDef into a local `Node`. Its `inputs` and
   * `children` are left empty for the caller to wire up.
   *
   * Params with a `tfInputIndex` come from a node input and get no value here.
   * All other params are read from the node's attributes.
   *
   * @throws Error if the op has no registered mapper, or a param declares an
   *     unsupported type.
   */
  private mapNode(node: tensorflow.INodeDef): Node {
    const mapper = this.opMappers[node.op];
    if (mapper === undefined) {
      throw new Error('Tensorflow Op is not supported: ' + node.op);
    }
    // A node without attributes may have no `attr` map at all.
    const attrs: {[key: string]: tensorflow.IAttrValue} = node.attr || {};
    const newNode: Node = {
      name: node.name,
      op: mapper.dlOpName,
      category: mapper.category,
      // Strip the leading '^' of control-dependency inputs so only the
      // referenced node name remains.
      inputNames: (node.input || [])
                      .map(input => input.startsWith('^') ? input.slice(1) :
                                                            input),
      inputs: [],
      children: [],
      params: {}
    };

    if (!!mapper.params) {
      newNode.params = mapper.params.reduce<{[key: string]: ParamValue}>(
          (map, param) => {
            const inputIndex = param.tfInputIndex;
            const inputParamLength = param.tfInputParamLength;
            const type = param.type;
            let value = undefined;
            if (inputIndex === undefined) {
              // Read the current attribute name, or the deprecated one if
              // only that is present, so a default never shadows it.
              const attrName = this.resolveAttrName(
                  attrs, param.tfParamName, param.tfParamNameDeprecated);
              switch (param.type) {
                case 'string':
                  value = this.getStringParam(
                      attrs, attrName, param.defaultValue as string);
                  break;
                case 'number':
                  value = this.getNumberParam(
                      attrs, attrName, (param.defaultValue || 0) as number);
                  break;
                case 'number[]':
                  value = this.getNumericArrayParam(
                      attrs, attrName, param.defaultValue as number[]);
                  break;
                case 'bool':
                  value = this.getBoolParam(
                      attrs, attrName, param.defaultValue as boolean);
                  break;
                case 'shape':
                  value = this.getTensorShapeParam(
                      attrs, attrName, param.defaultValue as number[]);
                  break;
                case 'dtype':
                  value = this.getDtypeParam(
                      attrs, attrName, param.defaultValue as DataType);
                  break;
                case 'tensor':
                case 'tensors':
                  break;
                default:
                  throw new Error(
                      `Unsupported param type: ${param.type} for op: ${node.op}`);
              }
            }
            map[param.dlParamName] = {value, inputIndex, type, inputParamLength};
            return map;
          },
          {});
    }
    return newNode;
  }

  /**
   * Chooses which attribute name to read for a param.
   *
   * Returns `deprecatedName` only when `name` is absent from `attrs` and
   * `deprecatedName` is present; otherwise returns `name`. Absence is tested
   * with `== null`.
   */
  private resolveAttrName(
      attrs: {[key: string]: tensorflow.IAttrValue}, name: string,
      deprecatedName?: string): string {
    if (attrs[name] == null && !!deprecatedName &&
        attrs[deprecatedName] != null) {
      return deprecatedName;
    }
    return name;
  }

  /**
   * Reads a string attribute, lower-cased unless `keepCase` is true.
   * Returns `def` when the attribute is absent.
   *
   * Each element of `s` becomes one UTF-16 code unit via
   * `String.fromCharCode`. Multi-byte UTF-8 is therefore not decoded; this
   * assumes the attributes are ASCII (TODO confirm).
   */
  private getStringParam(
      attrs: {[key: string]: tensorflow.IAttrValue}, name: string, def: string,
      keepCase = false): string {
    const param = attrs[name];
    if (param !== undefined) {
      const value = String.fromCharCode.apply(null, param.s);
      return keepCase ? value : value.toLowerCase();
    }
    return def;
  }

  /** Reads a bool attribute; returns `def` when the attribute is absent. */
  private getBoolParam(
      attrs: {[key: string]: tensorflow.IAttrValue}, name: string,
      def: boolean): boolean {
    const param = attrs[name];
    return param ? param.b : def;
  }

  /**
   * Reads a numeric attribute.
   *
   * The value is taken from the field named by `param.value` (presumably the
   * populated protobuf oneof field, e.g. 'i' or 'f'). Long values are
   * converted with `toInt()`. Returns `def` when the attribute is absent or
   * holds no value.
   */
  private getNumberParam(
      attrs: {[key: string]: tensorflow.IAttrValue}, name: string,
      def: number): number {
    const param = attrs[name] as tensorflow.AttrValue;
    if (!param) {
      return def;
    }
    const value = param[param.value] as number | Long;
    if (value == null) {
      // Guards against calling `toInt` on undefined.
      return def;
    }
    return (typeof value === 'number') ? value : value['toInt']() as number;
  }

  /**
   * Maps a TensorFlow dtype attribute to a tfjs `DataType`.
   *
   * Only DT_FLOAT, DT_INT32 and DT_BOOL are recognised. Any other dtype, or a
   * missing or falsy `type`, yields `def`.
   */
  private getDtypeParam(
      attrs: {[key: string]: tensorflow.IAttrValue}, name: string,
      def: DataType): DataType {
    const param = attrs[name];
    if (param && param.type) {
      switch (param.type) {
        case tensorflow.DataType.DT_FLOAT:
          return 'float32';
        case tensorflow.DataType.DT_INT32:
          return 'int32';
        case tensorflow.DataType.DT_BOOL:
          return 'bool';
        default:
          return def;
      }
    }
    return def;
  }

  /**
   * Reads a shape attribute as an array of dimension sizes; Long sizes are
   * converted with `toInt()`.
   *
   * @returns `def` when the attribute has no shape; `undefined` when the
   *     shape's rank is unknown; `[]` when the shape has no `dim` list.
   */
  private getTensorShapeParam(
      attrs: {[key: string]: tensorflow.IAttrValue}, name: string,
      def?: number[]): number[]|undefined {
    const param = attrs[name];
    if (param && param.shape) {
      if (param.shape.unknownRank) {
        // Unknown rank must not be reported as `[]`, which means a scalar.
        return undefined;
      }
      if (param.shape.dim == null) {
        return [];
      }
      return param.shape.dim.map(
          dim =>
              (typeof dim.size === 'number') ? dim.size : dim.size['toInt']());
    }
    return def;
  }

  /**
   * Reads a numeric list attribute.
   *
   * Uses the float list `f` when it is non-empty, otherwise the integer list
   * `i`; Long entries are converted with `toInt()`. Returns `def` when the
   * attribute or its `list` is absent, and `[]` when neither list is
   * populated.
   */
  private getNumericArrayParam(
      attrs: {[key: string]: tensorflow.IAttrValue}, name: string,
      def: number[]): number[] {
    const param = attrs[name];
    if (param) {
      const list = param.list;
      if (!list) {
        return def;
      }
      return ((list.f && list.f.length ? list.f : list.i) || [])
                 .map(v => (typeof v === 'number') ? v : v['toInt']()) as
          number[];
    }
    return def;
  }
}