@tensorflow/tfjs-converter
Version:
Tensorflow model converter for javascript
48 lines • 12 kB
JavaScript
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// tslint:disable-next-line: no-imports-from-dist
import * as tfOps from '@tensorflow/tfjs-core/dist/ops/ops_for_converter';
import { getParamValue } from './utils';
/**
 * Executes a node from the 'normalization' op category by dispatching on
 * `node.op` to the corresponding tfjs-core op.
 *
 * Every input is resolved by name through
 * `getParamValue(name, node, tensorMap, context)`.
 *
 * @param node Graph node to execute; `node.op` selects the operation.
 * @param tensorMap Named tensors, forwarded to `getParamValue` to resolve inputs.
 * @param context Execution context, forwarded to `getParamValue`.
 * @returns A one-element array holding the op's result.
 * @throws TypeError if `node.op` is not one of the ops handled below.
 */
export const executeOp = (node, tensorMap, context) => {
  // Shorthand: resolve a named parameter of this node.
  const param = (name) => getParamValue(name, node, tensorMap, context);

  switch (node.op) {
    case 'EuclideanNorm':
      return [tfOps.euclideanNorm(param('x'), param('axis'), param('keepDims'))];

    // The FusedBatchNorm variants take the same parameters here, so they
    // share one case; only a single result tensor is returned.
    case 'FusedBatchNorm':
    case 'FusedBatchNormV2':
    case 'FusedBatchNormV3':
      return [tfOps.batchNorm(
          param('x'), param('mean'), param('variance'), param('offset'),
          param('scale'), param('epsilon'))];

    case 'LRN':
      return [tfOps.localResponseNormalization(
          param('x'), param('radius'), param('bias'), param('alpha'),
          param('beta'))];

    case 'Softmax':
      return [tfOps.softmax(param('x'))];

    case 'LogSoftmax':
      return [tfOps.logSoftmax(param('x'))];

    case 'SparseToDense':
      // tfjs-core's sparseToDense takes
      // (sparseIndices, sparseValues, outputShape, defaultValue), which differs
      // from TF's SparseToDense input order
      // (sparse_indices, output_shape, sparse_values, default_value).
      // Pass values before shape so each lands in the right slot.
      return [tfOps.sparseToDense(
          param('sparseIndices'), param('sparseValues'), param('outputShape'),
          param('defaultValue'))];

    default:
      throw new TypeError(`Node type ${node.op} is not implemented`);
  }
};
/** Op category name under which this executor's ops are grouped. */
export const CATEGORY = "normalization";
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"normalization_executor.js","sourceRoot":"","sources":["../../../../../../../tfjs-converter/src/operations/executors/normalization_executor.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAGH,iDAAiD;AACjD,OAAO,KAAK,KAAK,MAAM,kDAAkD,CAAC;AAM1E,OAAO,EAAC,aAAa,EAAC,MAAM,SAAS,CAAC;AAEtC,MAAM,CAAC,MAAM,SAAS,GAClB,CAAC,IAAU,EAAE,SAA0B,EACtC,OAAyB,EAAY,EAAE;IACtC,QAAQ,IAAI,CAAC,EAAE,EAAE;QACf,KAAK,eAAe;YAClB,OAAO,CAAC,KAAK,CAAC,aAAa,CACvB,aAAa,CAAC,GAAG,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,EACtD,aAAa,CAAC,MAAM,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAa,EAC3D,aAAa,CAAC,UAAU,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAY,CAAC,CAAC,CAAC;QACvE,KAAK,gBAAgB,CAAC;QACtB,KAAK,kBAAkB,CAAC,CAAC;YACvB,OAAO,CAAC,KAAK,CAAC,SAAS,CACnB,aAAa,CAAC,GAAG,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,EACtD,aAAa,CAAC,MAAM,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,EACzD,aAAa,CAAC,UAAU,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,EAC7D,aAAa,CAAC,QAAQ,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,EAC3D,aAAa,CAAC,OAAO,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,EAC1D,aAAa,CAAC,SAAS,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,CAAC,CAAC,CAAC;SACpE;QACD,KAAK,kBAAkB,CAAC,CAAC;YACvB,OAAO,CAAC,KAAK,CAAC,SAAS,CACnB,aAAa,CAAC,GAAG,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,EACtD,aAAa,CAAC,MAAM,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,EACzD,aAAa,CAAC,UAAU,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,EAC7D,aAAa,CAAC,QAAQ,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,EAC3D,aAAa,CAAC,OAAO,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,EAC1D,aAAa,CAAC,SAAS,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,CAAC,CAAC,CAAC;SACpE;QACD,KAAK,KAAK,CAAC,CAAC;YACV,OAAO,CAAC,KAAK,CAAC,0BAA0B,CACpC,aAAa,CAAC,GAAG,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAC/B,EACZ,aAAa,CAAC,QAAQ,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,EAC3D,aAAa,CAAC,MAAM,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,EACzD,aAAa,CAAC,OAAO,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,EAC1D,aAAa,CAAC,MAAM,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,CAAC,CAAC,CAAC;SACjE;QACD,KAAK,SAAS,CAAC,CAAC;YACd,OAAO,CAAC,KAAK,CAAC,OAAO,CACjB,aAAa,CAAC,GAAG,EAAE,IAAI,EAAE,SAAS,EAAE
,OAAO,CAAW,CAAC,CAAC,CAAC;SAC9D;QACD,KAAK,YAAY,CAAC,CAAC;YACjB,OAAO,CAAC,KAAK,CAAC,UAAU,CACpB,aAAa,CAAC,GAAG,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,CAAC,CAAC,CAAC;SAC9D;QACD,KAAK,eAAe,CAAC,CAAC;YACpB,OAAO,CAAC,KAAK,CAAC,aAAa,CACvB,aAAa,CAAC,eAAe,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAC7C,EACV,aAAa,CAAC,aAAa,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,EAChE,aAAa,CAAC,cAAc,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAC1C,EACZ,aAAa,CAAC,cAAc,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAC5C,CAAC,CAAC,CAAC;SAClB;QACD;YACE,MAAM,SAAS,CAAC,aAAa,IAAI,CAAC,EAAE,qBAAqB,CAAC,CAAC;KAC9D;AACH,CAAC,CAAC;AAEN,MAAM,CAAC,MAAM,QAAQ,GAAG,eAAe,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2018 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {Scalar, Tensor, Tensor3D, Tensor4D} from '@tensorflow/tfjs-core';\n// tslint:disable-next-line: no-imports-from-dist\nimport * as tfOps from '@tensorflow/tfjs-core/dist/ops/ops_for_converter';\n\nimport {NamedTensorsMap} from '../../data/types';\nimport {ExecutionContext} from '../../executor/execution_context';\nimport {InternalOpExecutor, Node} from '../types';\n\nimport {getParamValue} from './utils';\n\nexport const executeOp: InternalOpExecutor =\n    (node: Node, tensorMap: NamedTensorsMap,\n     context: ExecutionContext): Tensor[] => {\n      switch (node.op) {\n        case 'EuclideanNorm':\n          return [tfOps.euclideanNorm(\n         
     getParamValue('x', node, tensorMap, context) as Tensor,\n              getParamValue('axis', node, tensorMap, context) as number[],\n              getParamValue('keepDims', node, tensorMap, context) as boolean)];\n        case 'FusedBatchNorm':\n        case 'FusedBatchNormV2': {\n          return [tfOps.batchNorm(\n              getParamValue('x', node, tensorMap, context) as Tensor,\n              getParamValue('mean', node, tensorMap, context) as Tensor,\n              getParamValue('variance', node, tensorMap, context) as Tensor,\n              getParamValue('offset', node, tensorMap, context) as Tensor,\n              getParamValue('scale', node, tensorMap, context) as Tensor,\n              getParamValue('epsilon', node, tensorMap, context) as number)];\n        }\n        case 'FusedBatchNormV3': {\n          return [tfOps.batchNorm(\n              getParamValue('x', node, tensorMap, context) as Tensor,\n              getParamValue('mean', node, tensorMap, context) as Tensor,\n              getParamValue('variance', node, tensorMap, context) as Tensor,\n              getParamValue('offset', node, tensorMap, context) as Tensor,\n              getParamValue('scale', node, tensorMap, context) as Tensor,\n              getParamValue('epsilon', node, tensorMap, context) as number)];\n        }\n        case 'LRN': {\n          return [tfOps.localResponseNormalization(\n              getParamValue('x', node, tensorMap, context) as Tensor3D |\n                  Tensor4D,\n              getParamValue('radius', node, tensorMap, context) as number,\n              getParamValue('bias', node, tensorMap, context) as number,\n              getParamValue('alpha', node, tensorMap, context) as number,\n              getParamValue('beta', node, tensorMap, context) as number)];\n        }\n        case 'Softmax': {\n          return [tfOps.softmax(\n              getParamValue('x', node, tensorMap, context) as Tensor)];\n        }\n        case 'LogSoftmax': {\n          
return [tfOps.logSoftmax(\n              getParamValue('x', node, tensorMap, context) as Tensor)];\n        }\n        case 'SparseToDense': {\n          return [tfOps.sparseToDense(\n              getParamValue('sparseIndices', node, tensorMap, context) as\n                  Tensor,\n              getParamValue('outputShape', node, tensorMap, context) as Tensor,\n              getParamValue('sparseValues', node, tensorMap, context) as\n                  number[],\n              getParamValue('defaultValue', node, tensorMap, context) as\n                  Scalar)];\n        }\n        default:\n          throw TypeError(`Node type ${node.op} is not implemented`);\n      }\n    };\n\nexport const CATEGORY = 'normalization';\n"]}