@tensorflow/tfjs
Version:
An open-source machine learning framework.
276 lines (268 loc) • 8.49 kB
text/typescript
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// tslint:disable-next-line: no-imports-from-dist
import * as tensorflow from '@tensorflow/tfjs-converter/dist/data/compiled_api';
import {io} from '@tensorflow/tfjs-core';
import * as fs from 'fs';
import {getOps} from './model_parser';
/**
 * Minimal graph-model fixture: a single `Placeholder` input followed by an
 * `Add` node and a `Sub` node. The spec in this file expects exactly the ops
 * `add` and `sub` to be reported for it.
 */
const SIMPLE_MODEL: io.ModelArtifacts = {
  format: 'graph-model',
  generatedBy: '0.0.0',
  convertedBy: 'Test Data',
  modelTopology: {
    node: [
      {
        name: 'Input',
        op: 'Placeholder',
        attr: {
          dtype: {type: tensorflow.DataType.DT_INT32},
          shape: {
            shape: {
              dim: [{size: -1}, {size: 1}],
            },
          },
        },
      },
      {
        name: 'Add1',
        op: 'Add',
        // 'Const' is referenced as an input but is not declared in this node
        // list.
        input: ['Input', 'Const'],
        attr: {},
      },
      {
        name: 'Sub',
        op: 'Sub',
        input: ['Add1', 'Input'],
        attr: {},
      },
    ],
    versions: {producer: 1.0, minConsumer: 3},
  },
};
/**
 * Larger graph-model fixture made of two parts:
 *
 *  - `node`: the main graph (Placeholder, Const, Fill, Conv2D, BiasAdd, Cast,
 *    Squeeze, Split, LogicalNot and FusedBatchNorm nodes).
 *  - `library.function`: two function definitions named
 *    `__inference_while_cond_10_49_frozen` and
 *    `__inference_while_body_11_40_frozen`, whose `nodeDef` lists contribute
 *    the `Less`, `Const` and `AddV2` kernels.
 *
 * No node in the main graph references the library functions by name; they
 * are present only as library entries.
 */
const CONTROL_FLOW_V2_MODEL: io.ModelArtifacts = {
  format: 'graph-model',
  generatedBy: '0.0.0',
  convertedBy: 'Test Data',
  modelTopology: {
    // ---- Main graph -------------------------------------------------------
    node: [
      {
        name: 'image_placeholder',
        op: 'Placeholder',
        attr: {
          dtype: {type: tensorflow.DataType.DT_FLOAT},
          // Sizes are intentionally kept as a mix of strings and numbers,
          // exactly as in the original fixture.
          shape: {
            shape: {
              dim: [{size: '3'}, {size: 3}, {size: '3'}, {size: 1}],
            },
          },
        },
      },
      {
        name: 'Const',
        op: 'Const',
        attr: {
          dtype: {type: tensorflow.DataType.DT_INT32},
          value: {
            tensor: {
              dtype: tensorflow.DataType.DT_INT32,
              tensorShape: {
                dim: [{size: 3}, {size: 3}, {size: 1}, {size: 1}],
              },
              intVal: [0, 0, 0, 0, 1, 0, 0, 0, 0],
            },
          },
        },
      },
      {
        name: 'Shape',
        op: 'Const',
        attr: {
          dtype: {type: tensorflow.DataType.DT_INT32},
          value: {
            tensor: {
              dtype: tensorflow.DataType.DT_INT32,
              tensorShape: {
                dim: [{size: 3}, {size: 1}, {size: 1}, {size: 1}],
              },
              intVal: [1, 1, 1],
            },
          },
        },
      },
      {
        name: 'Value',
        op: 'Const',
        attr: {
          dtype: {type: tensorflow.DataType.DT_INT32},
          value: {i: 1},
        },
      },
      {
        name: 'Fill',
        op: 'Fill',
        input: ['Shape', 'Value'],
        attr: {},
      },
      {
        name: 'Conv2D',
        op: 'Conv2D',
        input: ['image_placeholder', 'Const'],
        attr: {
          T: {type: tensorflow.DataType.DT_FLOAT},
          dataFormat: {s: 'TkhXQw=='},  // base64 for "NHWC"
          padding: {s: 'U0FNRQ=='},     // base64 for "SAME"
          strides: {list: {f: [], i: [1, 2, 2, 1]}},
          useCudnnOnGpu: {b: true},
        },
      },
      {
        name: 'BiasAdd',
        op: 'BiasAdd',
        input: ['Conv2D', 'Shape'],
        attr: {
          T: {type: tensorflow.DataType.DT_FLOAT},
          dataFormat: {s: 'TkhXQw=='},  // base64 for "NHWC"
        },
      },
      {
        name: 'Cast',
        op: 'Cast',
        input: ['BiasAdd'],
        attr: {
          DstT: {type: tensorflow.DataType.DT_INT64},
        },
      },
      {
        name: 'Squeeze',
        op: 'Squeeze',
        input: ['Cast'],
        attr: {
          // Dimension indices are given as strings here.
          squeeze_dims: {list: {i: ['1', '2']}},
        },
      },
      {
        name: 'Squeeze2',
        op: 'Squeeze',
        input: ['BiasAdd'],
        attr: {
          // Empty list: no explicit squeeze dimensions are provided.
          squeeze_dims: {list: {}},
        },
      },
      {
        name: 'Split',
        op: 'Split',
        input: ['image_placeholder'],
        attr: {
          num_split: {i: 3} as tensorflow.IAttrValue,
        },
      },
      {
        name: 'LogicalNot',
        op: 'LogicalNot',
        input: ['image_placeholder'],
      },
      {
        name: 'FusedBatchNorm',
        op: 'FusedBatchNorm',
        input: ['image_placeholder'],
        attr: {
          epsilon: {f: 0.0001} as tensorflow.IAttrValue,
        },
      },
      {
        name: 'Cast2',
        op: 'Cast',
        input: ['BiasAdd'],
        attr: {
          DstT: {type: tensorflow.DataType.DT_UINT8},
        },
      },
    ],

    // ---- Function library -------------------------------------------------
    library: {
      function: [
        // Function 1: a single `Less` node comparing two of its inputs.
        {
          signature: {
            name: '__inference_while_cond_10_49_frozen',
            inputArg: [
              {
                name: 'while_loop_counter',
                type: tensorflow.DataType.DT_INT32,
              },
              {
                name: 'while_maximum_iterations',
                type: tensorflow.DataType.DT_INT32,
              },
              {
                name: 'placeholder',
                type: tensorflow.DataType.DT_INT32,
              },
              {
                name: 'less_y',
                type: tensorflow.DataType.DT_INT32,
              },
              {
                name: 'while_cond_10___redundant_placeholder0',
                type: tensorflow.DataType.DT_INT32,
              },
            ],
            outputArg: [
              {
                name: 'identity',
                type: tensorflow.DataType.DT_BOOL,
              },
            ],
          },
          nodeDef: [
            {
              name: 'Less',
              op: 'Less',
              input: ['placeholder', 'less_y'],
              attr: {
                T: {type: tensorflow.DataType.DT_INT32},
              },
            },
          ],
          ret: {identity: 'Less:z:0'},
        },

        // Function 2: one scalar `Const` plus three `AddV2` nodes.
        {
          signature: {
            name: '__inference_while_body_11_40_frozen',
            inputArg: [
              {
                name: 'while_loop_counter',
                type: tensorflow.DataType.DT_INT32,
              },
              {
                name: 'while_maximum_iterations',
                type: tensorflow.DataType.DT_INT32,
              },
              {
                name: 'placeholder',
                type: tensorflow.DataType.DT_INT32,
              },
              {
                name: 'y_0',
                type: tensorflow.DataType.DT_INT32,
              },
              {
                name: 'add_1_z_0',
                type: tensorflow.DataType.DT_INT32,
              },
            ],
            outputArg: [
              {
                name: 'identity',
                type: tensorflow.DataType.DT_INT32,
              },
              {
                name: 'identity_1',
                type: tensorflow.DataType.DT_INT32,
              },
              {
                name: 'identity_2',
                type: tensorflow.DataType.DT_INT32,
              },
              {
                name: 'y',
                type: tensorflow.DataType.DT_INT32,
              },
              {
                name: 'add_1_z',
                type: tensorflow.DataType.DT_INT32,
              },
            ],
          },
          nodeDef: [
            {
              name: 'add_2/y',
              op: 'Const',
              attr: {
                dtype: {type: tensorflow.DataType.DT_INT32},
                value: {
                  tensor: {
                    dtype: tensorflow.DataType.DT_INT32,
                    tensorShape: {},
                    intVal: [1],
                  },
                },
              },
            },
            {
              name: 'add',
              op: 'AddV2',
              input: ['placeholder', 'y_0'],
              attr: {
                T: {type: tensorflow.DataType.DT_INT32},
              },
            },
            {
              name: 'add_2',
              op: 'AddV2',
              input: ['add_2/y:output:0', 'while_loop_counter'],
              attr: {
                T: {type: tensorflow.DataType.DT_INT32},
              },
            },
            {
              name: 'add_1',
              op: 'AddV2',
              input: ['add:z:0', 'add_1_z_0'],
              attr: {
                T: {type: tensorflow.DataType.DT_INT32},
              },
            },
          ],
          // Key order is kept exactly as in the original fixture.
          ret: {
            identity_1: 'while_maximum_iterations',
            identity_2: 'add_1:z:0',
            y: 'y_0',
            identity: 'add_2:z:0',
            add_1_z: 'add_1_z_0',
          },
        },
      ],
    },
    versions: {producer: 1.0},
  },
};
/**
 * Specs for `getOps`: given model artifacts and the kernel-to-op mapping read
 * from `@tensorflow/tfjs-converter/metadata/kernel2op.json`, it should return
 * the list of op names used by the model.
 */
describe('Model parse', () => {
  // Typed from getOps' own second parameter so the mapping handed to it is
  // checked against the real signature instead of being declared as `any`.
  let kernelToOps: Parameters<typeof getOps>[1];

  beforeAll(() => {
    // Read the mapping once for the whole suite; it is only read, never
    // mutated, by the specs below.
    const mappingPath =
        require.resolve('@tensorflow/tfjs-converter/metadata/kernel2op.json');
    kernelToOps = JSON.parse(fs.readFileSync(mappingPath, 'utf-8'));
  });

  it('should get ops from simple model', () => {
    const ops = getOps(SIMPLE_MODEL, kernelToOps);

    // arrayWithExactContents ignores ordering but requires the same elements.
    expect(ops).toEqual(jasmine.arrayWithExactContents(['add', 'sub']));
  });

  it('should get ops from control flow v2 model', () => {
    const ops = getOps(CONTROL_FLOW_V2_MODEL, kernelToOps);

    // 'less' is expected even though the only `Less` node in the fixture
    // lives inside a library function, not in the main node list.
    expect(ops).toEqual(jasmine.arrayWithExactContents([
      'cast', 'batchNorm', 'logicalNot', 'split', 'squeeze', 'add', 'conv2d',
      'fill', 'less'
    ]));
  });
});