@tensorflow-models/coco-ssd
Version:
Object detection model (coco-ssd) in TensorFlow.js
230 lines (216 loc) • 7.58 kB
text/typescript
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import {tensorflow_json} from '../data/compiled_api_json';
import * as arithmetic from './op_list/arithmetic';
import * as basicMath from './op_list/basic_math';
import * as control from './op_list/control';
import * as convolution from './op_list/convolution';
import * as creation from './op_list/creation';
import * as dynamic from './op_list/dynamic';
import * as evaluation from './op_list/evaluation';
import * as graph from './op_list/graph';
import * as image from './op_list/image';
import * as logical from './op_list/logical';
import * as matrices from './op_list/matrices';
import * as normalization from './op_list/normalization';
import * as reduction from './op_list/reduction';
import * as sliceJoin from './op_list/slice_join';
import * as spectral from './op_list/spectral';
import * as transformation from './op_list/transformation';
import {OperationMapper} from './operation_mapper_json';
import {Graph, OpMapper} from './types';
/**
 * Op-list modules whose `json` tables are iterated by the completeness check
 * in this file.
 */
const ops = [
  arithmetic,
  basicMath,
  control,
  convolution,
  creation,
  dynamic,
  evaluation,
  logical,
  image,
  graph,
  matrices,
  normalization,
  reduction,
  sliceJoin,
  spectral,
  transformation,
];

/** Mapper obtained from `OperationMapper.Instance`; used by every spec here. */
const mapper: OperationMapper = OperationMapper.Instance;

/** Holds the result of the latest `mapper.transformGraph` call. */
let convertedGraph: Graph;
/** Base64 form of the string 'NHWC', as stored in the `dataFormat` attr. */
const NHWC_BASE64 = 'TkhXQw==';

/** Base64 form of the string 'SAME', as stored in the `padding` attr. */
const SAME_BASE64 = 'U0FNRQ==';

/**
 * Small hand-written graph definition fed to `mapper.transformGraph` by the
 * `operationMapper` specs. It contains one placeholder, three constants and
 * six ops wired together through their `input` lists.
 */
const SIMPLE_MODEL: tensorflow_json.IGraphDef = {
  node: [
    // Float placeholder. The dims mix string and numeric sizes; the specs
    // expect them to be read back as [3, 3, 3, 1].
    {
      name: 'image_placeholder',
      op: 'Placeholder',
      attr: {
        dtype: {
          type: tensorflow_json.DataType.DT_FLOAT,
        },
        shape: {
          shape: {
            dim: [{size: '3'}, {size: 3}, {size: '3'}, {size: 1}],
          },
        },
      },
    },

    // int32 constant with nine values; second input of Conv2D.
    {
      name: 'Const',
      op: 'Const',
      attr: {
        dtype: {type: tensorflow_json.DataType.DT_INT32},
        value: {
          tensor: {
            dtype: tensorflow_json.DataType.DT_INT32,
            tensorShape: {
              dim: [{size: 3}, {size: 3}, {size: 1}, {size: 1}],
            },
            intVal: [0, 0, 0, 0, 1, 0, 0, 0, 0],
          },
        },
      },
    },

    // int32 constant with three values; input of both Fill and BiasAdd.
    {
      name: 'Shape',
      op: 'Const',
      attr: {
        dtype: {type: tensorflow_json.DataType.DT_INT32},
        value: {
          tensor: {
            dtype: tensorflow_json.DataType.DT_INT32,
            tensorShape: {
              dim: [{size: 3}, {size: 1}, {size: 1}, {size: 1}],
            },
            intVal: [1, 1, 1],
          },
        },
      },
    },

    // int32 constant whose value attr is {i: 1}; second input of Fill.
    {
      name: 'Value',
      op: 'Const',
      attr: {
        dtype: {type: tensorflow_json.DataType.DT_INT32},
        value: {i: 1},
      },
    },

    // Fill(Shape, Value) with no attributes.
    {
      name: 'Fill',
      op: 'Fill',
      input: ['Shape', 'Value'],
      attr: {},
    },

    // Conv2D(image_placeholder, Const).
    {
      name: 'Conv2D',
      op: 'Conv2D',
      input: ['image_placeholder', 'Const'],
      attr: {
        T: {type: tensorflow_json.DataType.DT_FLOAT},
        dataFormat: {s: NHWC_BASE64},
        padding: {s: SAME_BASE64},
        strides: {list: {f: [], i: [1, 2, 2, 1]}},
        useCudnnOnGpu: {b: true},
      },
    },

    // BiasAdd(Conv2D, Shape).
    {
      name: 'BiasAdd',
      op: 'BiasAdd',
      input: ['Conv2D', 'Shape'],
      attr: {
        T: {type: tensorflow_json.DataType.DT_FLOAT},
        dataFormat: {s: NHWC_BASE64},
      },
    },

    // Squeeze(BiasAdd). Uses the `squeeze_dims` attr name with string values;
    // the specs expect it to surface as the `axis` param [1, 2].
    {
      name: 'Squeeze',
      op: 'Squeeze',
      input: ['BiasAdd'],
      attr: {
        squeeze_dims: {list: {i: ['1', '2']}},
      },
    },

    // Split(image_placeholder) with num_split = 3.
    {
      name: 'Split',
      op: 'Split',
      input: ['image_placeholder'],
      attr: {
        num_split: {i: 3} as tensorflow_json.IAttrValue,
      },
    },

    // FusedBatchNorm(image_placeholder) with epsilon = 0.0001.
    {
      name: 'FusedBatchNorm',
      op: 'FusedBatchNorm',
      input: ['image_placeholder'],
      attr: {
        epsilon: {f: 0.0001} as tensorflow_json.IAttrValue,
      },
    },
  ],
  versions: {producer: 1.0},
};
/**
 * Checks every entry of every op-list `json` table: a graph holding a single
 * node of that TensorFlow op must convert to exactly one node, keyed by the
 * TensorFlow op name, whose `op` equals the entry's `dlOpName`.
 */
describe('completeness check', () => {
  it('should convert all op categories', () => {
    for (const category of ops) {
      for (const tfOp of category.json as OpMapper[]) {
        // Not called `graph`: that name is already the imported op-list
        // module and shadowing it here would be misleading.
        const singleNodeGraph = {
          node: [{name: tfOp.tfOpName, op: tfOp.tfOpName, attr: {}}]
        };

        // Kept local so this spec does not write to module-level state.
        const converted = mapper.transformGraph(singleNodeGraph);

        expect(Object.keys(converted.nodes)).toEqual([tfOp.tfOpName]);
        expect(converted.nodes[tfOp.tfOpName].op).toEqual(tfOp.dlOpName);
      }
    }
  });
});
/**
 * Specs for `mapper.transformGraph` applied to `SIMPLE_MODEL`: graph-level
 * collections (inputs, outputs, weights, nodes) and per-node wiring and
 * parameter mapping.
 */
describe('operationMapper', () => {
  // Re-run the conversion before every spec.
  beforeEach(() => {
    convertedGraph = mapper.transformGraph(SIMPLE_MODEL);
  });

  describe('transform graph', () => {
    describe('graph level', () => {
      it('should find the graph input nodes', () => {
        const inputNames = convertedGraph.inputs.map(node => node.name);
        expect(inputNames).toEqual([
          'image_placeholder', 'Const', 'Shape', 'Value'
        ]);
      });

      it('should find the graph output nodes', () => {
        const outputNames = convertedGraph.outputs.map(node => node.name);
        expect(outputNames).toEqual([
          'Fill', 'Squeeze', 'Split', 'FusedBatchNorm'
        ]);
      });

      it('should find the graph weight nodes', () => {
        const weightNames = convertedGraph.weights.map(node => node.name);
        expect(weightNames).toEqual(['Const', 'Shape', 'Value']);
      });

      it('should convert nodes', () => {
        // Keys are expected in the same order as the nodes of SIMPLE_MODEL.
        expect(Object.keys(convertedGraph.nodes)).toEqual([
          'image_placeholder', 'Const', 'Shape', 'Value', 'Fill', 'Conv2D',
          'BiasAdd', 'Squeeze', 'Split', 'FusedBatchNorm'
        ]);
      });
    });

    describe('node level', () => {
      it('should find the input nodes', () => {
        const fillInputs =
            convertedGraph.nodes['Fill'].inputs.map(node => node.name);
        expect(fillInputs).toEqual(['Shape', 'Value']);
      });

      it('should find the children nodes', () => {
        const placeholderChildren =
            convertedGraph.nodes['image_placeholder'].children.map(
                node => node.name);
        expect(placeholderChildren).toEqual([
          'Conv2D', 'Split', 'FusedBatchNorm'
        ]);
      });

      it('should map the input params', () => {
        const fillParams = convertedGraph.nodes['Fill'].params;
        expect(fillParams['shape'].inputIndex).toEqual(0);
        expect(fillParams['value'].inputIndex).toEqual(1);
      });

      it('should map the attribute params', () => {
        const conv2dParams = convertedGraph.nodes['Conv2D'].params;
        expect(conv2dParams['strides'].value).toEqual([1, 2, 2, 1]);
        // The model stores padding as base64 'U0FNRQ==' ("SAME").
        expect(conv2dParams['pad'].value).toEqual('same');
        expect(conv2dParams['useCudnnOnGpu'].value).toEqual(true);

        expect(convertedGraph.nodes['Split'].params['numOrSizeSplits'].value)
            .toEqual(3);
        expect(convertedGraph.nodes['FusedBatchNorm'].params['epsilon'].value)
            .toEqual(0.0001);
      });

      it('should map the placeholder attribute params', () => {
        const placeholderParams =
            convertedGraph.nodes['image_placeholder'].params;
        // The model's dims mix string and numeric sizes; numbers are expected.
        expect(placeholderParams['shape'].value).toEqual([3, 3, 3, 1]);
        expect(placeholderParams['dtype'].value).toEqual('float32');
      });

      it('should map params with deprecated name', () => {
        // The model supplies this attribute as `squeeze_dims`.
        expect(convertedGraph.nodes['Squeeze'].params['axis'].value).toEqual([
          1, 2
        ]);
      });
    });
  });
});