UNPKG

@tensorflow/tfjs

Version:

An open-source machine learning framework.

276 lines (268 loc) 8.49 kB
/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */

// tslint:disable-next-line: no-imports-from-dist
import * as tensorflow from '@tensorflow/tfjs-converter/dist/data/compiled_api';
import {io} from '@tensorflow/tfjs-core';
import * as fs from 'fs';

import {getOps} from './model_parser';

// Short aliases for the dtype enum members used by the fixtures below.
const DT_BOOL = tensorflow.DataType.DT_BOOL;
const DT_FLOAT = tensorflow.DataType.DT_FLOAT;
const DT_INT32 = tensorflow.DataType.DT_INT32;
const DT_INT64 = tensorflow.DataType.DT_INT64;
const DT_UINT8 = tensorflow.DataType.DT_UINT8;

// Base64-encoded string attribute values ('NHWC' and 'SAME' respectively).
const DATA_FORMAT_NHWC = 'TkhXQw==';
const PADDING_SAME = 'U0FNRQ==';

/** Wraps a dtype in the `{type: ...}` attribute form. */
const typeAttr = (type: tensorflow.DataType) => ({type});

/** Turns a list of dimension sizes into the `{dim: [{size}, ...]}` form. */
const shapeOf = (...sizes: Array<number|string>) =>
    ({dim: sizes.map(size => ({size}))});

/** Builds a named, typed entry for a function signature argument list. */
const signatureArg = (name: string, type: tensorflow.DataType) =>
    ({name, type});

/** Builds the `attr` payload of an int32 `Const` node holding a tensor. */
const int32ConstAttr = (tensorShape: object, intVal: number[]) => ({
  dtype: typeAttr(DT_INT32),
  value: {tensor: {dtype: DT_INT32, tensorShape, intVal}}
});

/**
 * Minimal graph: one placeholder feeding an `Add` node followed by a `Sub`
 * node. The spec below expects `getOps` to report exactly `add` and `sub`.
 */
const SIMPLE_MODEL: io.ModelArtifacts = {
  format: 'graph-model',
  generatedBy: '0.0.0',
  convertedBy: 'Test Data',
  modelTopology: {
    node: [
      {
        name: 'Input',
        op: 'Placeholder',
        attr: {dtype: typeAttr(DT_INT32), shape: {shape: shapeOf(-1, 1)}}
      },
      {name: 'Add1', op: 'Add', input: ['Input', 'Const'], attr: {}},
      {name: 'Sub', op: 'Sub', input: ['Add1', 'Input'], attr: {}}
    ],
    versions: {producer: 1.0, minConsumer: 3}
  }
};

// Top-level nodes of the control flow v2 fixture.
const CONTROL_FLOW_V2_NODES = [
  {
    name: 'image_placeholder',
    op: 'Placeholder',
    attr: {
      dtype: typeAttr(DT_FLOAT),
      // Sizes deliberately mix string and number forms.
      shape: {shape: shapeOf('3', 3, '3', 1)}
    }
  },
  {
    name: 'Const',
    op: 'Const',
    attr: int32ConstAttr(shapeOf(3, 3, 1, 1), [0, 0, 0, 0, 1, 0, 0, 0, 0])
  },
  {
    name: 'Shape',
    op: 'Const',
    attr: int32ConstAttr(shapeOf(3, 1, 1, 1), [1, 1, 1])
  },
  {
    name: 'Value',
    op: 'Const',
    attr: {dtype: typeAttr(DT_INT32), value: {i: 1}}
  },
  {name: 'Fill', op: 'Fill', input: ['Shape', 'Value'], attr: {}},
  {
    name: 'Conv2D',
    op: 'Conv2D',
    input: ['image_placeholder', 'Const'],
    attr: {
      T: typeAttr(DT_FLOAT),
      dataFormat: {s: DATA_FORMAT_NHWC},
      padding: {s: PADDING_SAME},
      strides: {list: {f: [], i: [1, 2, 2, 1]}},
      useCudnnOnGpu: {b: true}
    }
  },
  {
    name: 'BiasAdd',
    op: 'BiasAdd',
    input: ['Conv2D', 'Shape'],
    attr: {T: typeAttr(DT_FLOAT), dataFormat: {s: DATA_FORMAT_NHWC}}
  },
  {
    name: 'Cast',
    op: 'Cast',
    input: ['BiasAdd'],
    attr: {DstT: typeAttr(DT_INT64)}
  },
  {
    name: 'Squeeze',
    op: 'Squeeze',
    input: ['Cast'],
    attr: {squeeze_dims: {list: {i: ['1', '2']}}}
  },
  {
    name: 'Squeeze2',
    op: 'Squeeze',
    input: ['BiasAdd'],
    attr: {squeeze_dims: {list: {}}}
  },
  {
    name: 'Split',
    op: 'Split',
    input: ['image_placeholder'],
    attr: {num_split: {i: 3} as tensorflow.IAttrValue}
  },
  // Note: this node intentionally carries no `attr` field.
  {name: 'LogicalNot', op: 'LogicalNot', input: ['image_placeholder']},
  {
    name: 'FusedBatchNorm',
    op: 'FusedBatchNorm',
    input: ['image_placeholder'],
    attr: {epsilon: {f: 0.0001} as tensorflow.IAttrValue}
  },
  {
    name: 'Cast2',
    op: 'Cast',
    input: ['BiasAdd'],
    attr: {DstT: typeAttr(DT_UINT8)}
  },
];

// Library function holding a single `Less` node.
const WHILE_COND_FUNCTION = {
  signature: {
    name: '__inference_while_cond_10_49_frozen',
    inputArg: [
      signatureArg('while_loop_counter', DT_INT32),
      signatureArg('while_maximum_iterations', DT_INT32),
      signatureArg('placeholder', DT_INT32),
      signatureArg('less_y', DT_INT32),
      signatureArg('while_cond_10___redundant_placeholder0', DT_INT32)
    ],
    outputArg: [signatureArg('identity', DT_BOOL)]
  },
  nodeDef: [{
    name: 'Less',
    op: 'Less',
    input: ['placeholder', 'less_y'],
    attr: {T: typeAttr(DT_INT32)}
  }],
  ret: {identity: 'Less:z:0'}
};

// Library function holding one `Const` node and three `AddV2` nodes.
const WHILE_BODY_FUNCTION = {
  signature: {
    name: '__inference_while_body_11_40_frozen',
    inputArg: [
      signatureArg('while_loop_counter', DT_INT32),
      signatureArg('while_maximum_iterations', DT_INT32),
      signatureArg('placeholder', DT_INT32),
      signatureArg('y_0', DT_INT32),
      signatureArg('add_1_z_0', DT_INT32)
    ],
    outputArg: [
      signatureArg('identity', DT_INT32),
      signatureArg('identity_1', DT_INT32),
      signatureArg('identity_2', DT_INT32),
      signatureArg('y', DT_INT32),
      signatureArg('add_1_z', DT_INT32)
    ]
  },
  nodeDef: [
    {name: 'add_2/y', op: 'Const', attr: int32ConstAttr({}, [1])},
    {
      name: 'add',
      op: 'AddV2',
      input: ['placeholder', 'y_0'],
      attr: {T: typeAttr(DT_INT32)}
    },
    {
      name: 'add_2',
      op: 'AddV2',
      input: ['add_2/y:output:0', 'while_loop_counter'],
      attr: {T: typeAttr(DT_INT32)}
    },
    {
      name: 'add_1',
      op: 'AddV2',
      input: ['add:z:0', 'add_1_z_0'],
      attr: {T: typeAttr(DT_INT32)}
    }
  ],
  ret: {
    identity_1: 'while_maximum_iterations',
    identity_2: 'add_1:z:0',
    y: 'y_0',
    identity: 'add_2:z:0',
    add_1_z: 'add_1_z_0'
  }
};

/**
 * Graph whose ops live both in the top-level node list and inside the two
 * functions of its `library` section. The spec below lists the ops `getOps`
 * is expected to report for it.
 */
const CONTROL_FLOW_V2_MODEL: io.ModelArtifacts = {
  format: 'graph-model',
  generatedBy: '0.0.0',
  convertedBy: 'Test Data',
  modelTopology: {
    node: CONTROL_FLOW_V2_NODES,
    library: {function: [WHILE_COND_FUNCTION, WHILE_BODY_FUNCTION]},
    versions: {producer: 1.0}
  }
};

describe('Model parse', () => {
  // Parsed contents of kernel2op.json, loaded once for the whole suite.
  // tslint:disable-next-line: no-any
  let kernelToOps: any;

  beforeAll(() => {
    const kernel2opFile =
        require.resolve('@tensorflow/tfjs-converter/metadata/kernel2op.json');
    const rawMapping = fs.readFileSync(kernel2opFile, 'utf-8');
    kernelToOps = JSON.parse(rawMapping);
  });

  it('should get ops from simple model', () => {
    const expectedOps = ['add', 'sub'];

    const actualOps = getOps(SIMPLE_MODEL, kernelToOps);

    // Order-insensitive comparison: only the set of reported ops matters.
    expect(actualOps).toEqual(jasmine.arrayWithExactContents(expectedOps));
  });

  it('should get ops from control flow v2 model', () => {
    const expectedOps = [
      'cast', 'batchNorm', 'logicalNot', 'split', 'squeeze', 'add', 'conv2d',
      'fill', 'less'
    ];

    const actualOps = getOps(CONTROL_FLOW_V2_MODEL, kernelToOps);

    expect(actualOps).toEqual(jasmine.arrayWithExactContents(expectedOps));
  });
});