@tensorflow/tfjs-converter
Version:
Tensorflow model converter for javascript
94 lines • 19.9 kB
JavaScript
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// tslint:disable-next-line: no-imports-from-dist
import * as tfOps from '@tensorflow/tfjs-core/dist/ops/ops_for_converter';
import { getParamValue } from './utils';
/**
 * Executes a single graph node belonging to the "reduction" op category.
 *
 * Every input/attribute of the node is resolved through `getParamValue`, and
 * the matching `tfOps` kernel is invoked with the resolved values. Parameter
 * lookups happen in a fixed order per op (attributes first, then `x` for the
 * plain reductions) to mirror the original executor exactly.
 *
 * @param node Graph node to run; `node.op` selects the operation.
 * @param tensorMap Named tensors visible to the node; forwarded untouched to
 *     `getParamValue` when resolving inputs.
 * @param context Execution context; also forwarded to `getParamValue`.
 * @returns A one-element array holding the op's output.
 * @throws {TypeError} When `node.op` is not one of the ops handled here.
 */
export const executeOp = (node, tensorMap, context) => {
  switch (node.op) {
    // Reductions over an axis list with a keepDims flag.
    case 'Max': {
      const axis = getParamValue('axis', node, tensorMap, context);
      const keepDims = getParamValue('keepDims', node, tensorMap, context);
      return [tfOps.max(
          getParamValue('x', node, tensorMap, context), axis, keepDims)];
    }
    case 'Mean': {
      const axis = getParamValue('axis', node, tensorMap, context);
      const keepDims = getParamValue('keepDims', node, tensorMap, context);
      return [tfOps.mean(
          getParamValue('x', node, tensorMap, context), axis, keepDims)];
    }
    case 'Min': {
      const axis = getParamValue('axis', node, tensorMap, context);
      const keepDims = getParamValue('keepDims', node, tensorMap, context);
      return [tfOps.min(
          getParamValue('x', node, tensorMap, context), axis, keepDims)];
    }
    case 'Sum': {
      const axis = getParamValue('axis', node, tensorMap, context);
      const keepDims = getParamValue('keepDims', node, tensorMap, context);
      return [tfOps.sum(
          getParamValue('x', node, tensorMap, context), axis, keepDims)];
    }
    case 'All': {
      const axis = getParamValue('axis', node, tensorMap, context);
      const keepDims = getParamValue('keepDims', node, tensorMap, context);
      return [tfOps.all(
          getParamValue('x', node, tensorMap, context), axis, keepDims)];
    }
    case 'Any': {
      const axis = getParamValue('axis', node, tensorMap, context);
      const keepDims = getParamValue('keepDims', node, tensorMap, context);
      return [tfOps.any(
          getParamValue('x', node, tensorMap, context), axis, keepDims)];
    }
    // Index-of-extremum ops take a single axis and no keepDims flag.
    case 'ArgMax': {
      const axis = getParamValue('axis', node, tensorMap, context);
      return [tfOps.argMax(getParamValue('x', node, tensorMap, context), axis)];
    }
    case 'ArgMin': {
      const axis = getParamValue('axis', node, tensorMap, context);
      return [tfOps.argMin(getParamValue('x', node, tensorMap, context), axis)];
    }
    case 'Prod': {
      const axis = getParamValue('axis', node, tensorMap, context);
      const keepDims = getParamValue('keepDims', node, tensorMap, context);
      return [tfOps.prod(
          getParamValue('x', node, tensorMap, context), axis, keepDims)];
    }
    // Cumulative ops: single axis plus exclusive/reverse flags.
    case 'Cumprod': {
      const axis = getParamValue('axis', node, tensorMap, context);
      const exclusive = getParamValue('exclusive', node, tensorMap, context);
      const reverse = getParamValue('reverse', node, tensorMap, context);
      return [tfOps.cumprod(
          getParamValue('x', node, tensorMap, context), axis, exclusive,
          reverse)];
    }
    case 'Cumsum': {
      const axis = getParamValue('axis', node, tensorMap, context);
      const exclusive = getParamValue('exclusive', node, tensorMap, context);
      const reverse = getParamValue('reverse', node, tensorMap, context);
      return [tfOps.cumsum(
          getParamValue('x', node, tensorMap, context), axis, exclusive,
          reverse)];
    }
    // Braced so that x/weights/size are scoped to this clause only instead of
    // leaking (with a TDZ) into the whole switch (ESLint no-case-declarations).
    case 'Bincount': {
      const x = getParamValue('x', node, tensorMap, context);
      const weights = getParamValue('weights', node, tensorMap, context);
      const size = getParamValue('size', node, tensorMap, context);
      return [tfOps.bincount(x, weights, size)];
    }
    case 'DenseBincount': {
      const x = getParamValue('x', node, tensorMap, context);
      const weights = getParamValue('weights', node, tensorMap, context);
      const size = getParamValue('size', node, tensorMap, context);
      const binaryOutput =
          getParamValue('binaryOutput', node, tensorMap, context);
      return [tfOps.denseBincount(x, weights, size, binaryOutput)];
    }
    default:
      throw new TypeError(`Node type ${node.op} is not implemented`);
  }
};
/**
 * Op-category tag for this executor module: the ops dispatched by
 * `executeOp` in this file are all grouped under "reduction".
 */
export const CATEGORY = 'reduction';
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"reduction_executor.js","sourceRoot":"","sources":["../../../../../../../tfjs-converter/src/operations/executors/reduction_executor.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAGH,iDAAiD;AACjD,OAAO,KAAK,KAAK,MAAM,kDAAkD,CAAC;AAM1E,OAAO,EAAC,aAAa,EAAC,MAAM,SAAS,CAAC;AAEtC,MAAM,CAAC,MAAM,SAAS,GAClB,CAAC,IAAU,EAAE,SAA0B,EACtC,OAAyB,EAAY,EAAE;IACtC,QAAQ,IAAI,CAAC,EAAE,EAAE;QACf,KAAK,KAAK,CAAC,CAAC;YACV,MAAM,IAAI,GACN,aAAa,CAAC,MAAM,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAa,CAAC;YAChE,MAAM,QAAQ,GACV,aAAa,CAAC,UAAU,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAY,CAAC;YACnE,OAAO,CAAC,KAAK,CAAC,GAAG,CACb,aAAa,CAAC,GAAG,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,EAAE,IAAI,EAC5D,QAAQ,CAAC,CAAC,CAAC;SAChB;QACD,KAAK,MAAM,CAAC,CAAC;YACX,MAAM,IAAI,GACN,aAAa,CAAC,MAAM,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAa,CAAC;YAChE,MAAM,QAAQ,GACV,aAAa,CAAC,UAAU,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAY,CAAC;YACnE,OAAO,CAAC,KAAK,CAAC,IAAI,CACd,aAAa,CAAC,GAAG,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,EAAE,IAAI,EAC5D,QAAQ,CAAC,CAAC,CAAC;SAChB;QACD,KAAK,KAAK,CAAC,CAAC;YACV,MAAM,IAAI,GACN,aAAa,CAAC,MAAM,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAa,CAAC;YAChE,MAAM,QAAQ,GACV,aAAa,CAAC,UAAU,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAY,CAAC;YACnE,OAAO,CAAC,KAAK,CAAC,GAAG,CACb,aAAa,CAAC,GAAG,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,EAAE,IAAI,EAC5D,QAAQ,CAAC,CAAC,CAAC;SAChB;QACD,KAAK,KAAK,CAAC,CAAC;YACV,MAAM,IAAI,GACN,aAAa,CAAC,MAAM,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAa,CAAC;YAChE,MAAM,QAAQ,GACV,aAAa,CAAC,UAAU,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAY,CAAC;YACnE,OAAO,CAAC,KAAK,CAAC,GAAG,CACb,aAAa,CAAC,GAAG,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,EAAE,IAAI,EAC5D,QAAQ,CAAC,CAAC,CAAC;SAChB;QACD,KAAK,KAAK,CAAC,CAAC;YACV,MAAM,IAAI,GACN,aAAa,CAAC,MAAM,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAa,CAAC;YAChE,MAAM,QAAQ,GACV,aAAa,CAAC,UAAU,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAY,CAAC;YACnE,OAAO,CAAC,KAAK,CAAC,GAAG,CACb,aAAa,CAAC,GAAG,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,EAAE,IAAI,EAC5D,QAAQ,CAAC,CAAC,CAAC;SAChB;QACD,KAAK,KAAK,CAAC,CAAC;YACV
,MAAM,IAAI,GACN,aAAa,CAAC,MAAM,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAa,CAAC;YAChE,MAAM,QAAQ,GACV,aAAa,CAAC,UAAU,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAY,CAAC;YACnE,OAAO,CAAC,KAAK,CAAC,GAAG,CACb,aAAa,CAAC,GAAG,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,EAAE,IAAI,EAC5D,QAAQ,CAAC,CAAC,CAAC;SAChB;QACD,KAAK,QAAQ,CAAC,CAAC;YACb,MAAM,IAAI,GACN,aAAa,CAAC,MAAM,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,CAAC;YAC9D,OAAO,CAAC,KAAK,CAAC,MAAM,CAChB,aAAa,CAAC,GAAG,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,EAAE,IAAI,CAAC,CAAC,CAAC;SACpE;QACD,KAAK,QAAQ,CAAC,CAAC;YACb,MAAM,IAAI,GACN,aAAa,CAAC,MAAM,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,CAAC;YAC9D,OAAO,CAAC,KAAK,CAAC,MAAM,CAChB,aAAa,CAAC,GAAG,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,EAAE,IAAI,CAAC,CAAC,CAAC;SACpE;QACD,KAAK,MAAM,CAAC,CAAC;YACX,MAAM,IAAI,GACN,aAAa,CAAC,MAAM,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAa,CAAC;YAChE,MAAM,QAAQ,GACV,aAAa,CAAC,UAAU,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAY,CAAC;YACnE,OAAO,CAAC,KAAK,CAAC,IAAI,CACd,aAAa,CAAC,GAAG,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,EAAE,IAAI,EAC5D,QAAQ,CAAC,CAAC,CAAC;SAChB;QACD,KAAK,SAAS,CAAC,CAAC;YACd,MAAM,IAAI,GACN,aAAa,CAAC,MAAM,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,CAAC;YAC9D,MAAM,SAAS,GACX,aAAa,CAAC,WAAW,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAY,CAAC;YACpE,MAAM,OAAO,GACT,aAAa,CAAC,SAAS,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAY,CAAC;YAClE,OAAO,CAAC,KAAK,CAAC,OAAO,CACjB,aAAa,CAAC,GAAG,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,EAAE,IAAI,EAC5D,SAAS,EAAE,OAAO,CAAC,CAAC,CAAC;SAC1B;QACD,KAAK,QAAQ,CAAC,CAAC;YACb,MAAM,IAAI,GACN,aAAa,CAAC,MAAM,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,CAAC;YAC9D,MAAM,SAAS,GACX,aAAa,CAAC,WAAW,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAY,CAAC;YACpE,MAAM,OAAO,GACT,aAAa,CAAC,SAAS,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAY,CAAC;YAClE,OAAO,CAAC,KAAK,CAAC,MAAM,CAChB,aAAa,CAAC,GAAG,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,EAAE,IAAI,EAC5D,SAAS,EAAE,OAAO,CAAC,CAAC,CAAC;SAC1B;QACD,KAAK,UAAU;YACb,MAAM,CAAC,GAAG,aAAa,CAAC,GAAG,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAa,CAAC;YACnE,MAAM,OAAO,GACT,aAAa,CAAC,SAAS,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAa,CAAC;YACnE,MAAM,IAAI,GACN,aAAa,CAAC,MAAM,E
AAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,CAAC;YAE9D,OAAO,CAAC,KAAK,CAAC,QAAQ,CAAC,CAAC,EAAE,OAAO,EAAE,IAAI,CAAC,CAAC,CAAC;QAC5C,KAAK,eAAe,CAAC,CAAC;YACpB,MAAM,CAAC,GAAG,aAAa,CAAC,GAAG,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CACzC,CAAC;YACb,MAAM,OAAO,GACT,aAAa,CAAC,SAAS,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CACzC,CAAC;YACb,MAAM,IAAI,GACN,aAAa,CAAC,MAAM,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,CAAC;YAE9D,MAAM,YAAY,GACd,aAAa,CAAC,cAAc,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAC/C,CAAC;YAEZ,OAAO,CAAC,KAAK,CAAC,aAAa,CAAC,CAAC,EAAE,OAAO,EAAE,IAAI,EAAE,YAAY,CAAC,CAAC,CAAC;SAC9D;QACD;YACE,MAAM,SAAS,CAAC,aAAa,IAAI,CAAC,EAAE,qBAAqB,CAAC,CAAC;KAC9D;AACH,CAAC,CAAC;AAEN,MAAM,CAAC,MAAM,QAAQ,GAAG,WAAW,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2018 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {Tensor, Tensor1D, Tensor2D} from '@tensorflow/tfjs-core';\n// tslint:disable-next-line: no-imports-from-dist\nimport * as tfOps from '@tensorflow/tfjs-core/dist/ops/ops_for_converter';\n\nimport {NamedTensorsMap} from '../../data/types';\nimport {ExecutionContext} from '../../executor/execution_context';\nimport {InternalOpExecutor, Node} from '../types';\n\nimport {getParamValue} from './utils';\n\nexport const executeOp: InternalOpExecutor =\n    (node: Node, tensorMap: NamedTensorsMap,\n     context: ExecutionContext): Tensor[] => {\n      switch (node.op) {\n        case 
'Max': {\n          const axis =\n              getParamValue('axis', node, tensorMap, context) as number[];\n          const keepDims =\n              getParamValue('keepDims', node, tensorMap, context) as boolean;\n          return [tfOps.max(\n              getParamValue('x', node, tensorMap, context) as Tensor, axis,\n              keepDims)];\n        }\n        case 'Mean': {\n          const axis =\n              getParamValue('axis', node, tensorMap, context) as number[];\n          const keepDims =\n              getParamValue('keepDims', node, tensorMap, context) as boolean;\n          return [tfOps.mean(\n              getParamValue('x', node, tensorMap, context) as Tensor, axis,\n              keepDims)];\n        }\n        case 'Min': {\n          const axis =\n              getParamValue('axis', node, tensorMap, context) as number[];\n          const keepDims =\n              getParamValue('keepDims', node, tensorMap, context) as boolean;\n          return [tfOps.min(\n              getParamValue('x', node, tensorMap, context) as Tensor, axis,\n              keepDims)];\n        }\n        case 'Sum': {\n          const axis =\n              getParamValue('axis', node, tensorMap, context) as number[];\n          const keepDims =\n              getParamValue('keepDims', node, tensorMap, context) as boolean;\n          return [tfOps.sum(\n              getParamValue('x', node, tensorMap, context) as Tensor, axis,\n              keepDims)];\n        }\n        case 'All': {\n          const axis =\n              getParamValue('axis', node, tensorMap, context) as number[];\n          const keepDims =\n              getParamValue('keepDims', node, tensorMap, context) as boolean;\n          return [tfOps.all(\n              getParamValue('x', node, tensorMap, context) as Tensor, axis,\n              keepDims)];\n        }\n        case 'Any': {\n          const axis =\n              getParamValue('axis', node, tensorMap, context) as number[];\n          
const keepDims =\n              getParamValue('keepDims', node, tensorMap, context) as boolean;\n          return [tfOps.any(\n              getParamValue('x', node, tensorMap, context) as Tensor, axis,\n              keepDims)];\n        }\n        case 'ArgMax': {\n          const axis =\n              getParamValue('axis', node, tensorMap, context) as number;\n          return [tfOps.argMax(\n              getParamValue('x', node, tensorMap, context) as Tensor, axis)];\n        }\n        case 'ArgMin': {\n          const axis =\n              getParamValue('axis', node, tensorMap, context) as number;\n          return [tfOps.argMin(\n              getParamValue('x', node, tensorMap, context) as Tensor, axis)];\n        }\n        case 'Prod': {\n          const axis =\n              getParamValue('axis', node, tensorMap, context) as number[];\n          const keepDims =\n              getParamValue('keepDims', node, tensorMap, context) as boolean;\n          return [tfOps.prod(\n              getParamValue('x', node, tensorMap, context) as Tensor, axis,\n              keepDims)];\n        }\n        case 'Cumprod': {\n          const axis =\n              getParamValue('axis', node, tensorMap, context) as number;\n          const exclusive =\n              getParamValue('exclusive', node, tensorMap, context) as boolean;\n          const reverse =\n              getParamValue('reverse', node, tensorMap, context) as boolean;\n          return [tfOps.cumprod(\n              getParamValue('x', node, tensorMap, context) as Tensor, axis,\n              exclusive, reverse)];\n        }\n        case 'Cumsum': {\n          const axis =\n              getParamValue('axis', node, tensorMap, context) as number;\n          const exclusive =\n              getParamValue('exclusive', node, tensorMap, context) as boolean;\n          const reverse =\n              getParamValue('reverse', node, tensorMap, context) as boolean;\n          return [tfOps.cumsum(\n              
getParamValue('x', node, tensorMap, context) as Tensor, axis,\n              exclusive, reverse)];\n        }\n        case 'Bincount':\n          const x = getParamValue('x', node, tensorMap, context) as Tensor1D;\n          const weights =\n              getParamValue('weights', node, tensorMap, context) as Tensor1D;\n          const size =\n              getParamValue('size', node, tensorMap, context) as number;\n\n          return [tfOps.bincount(x, weights, size)];\n        case 'DenseBincount': {\n          const x = getParamValue('x', node, tensorMap, context) as Tensor1D |\n              Tensor2D;\n          const weights =\n              getParamValue('weights', node, tensorMap, context) as Tensor1D |\n              Tensor2D;\n          const size =\n              getParamValue('size', node, tensorMap, context) as number;\n\n          const binaryOutput =\n              getParamValue('binaryOutput', node, tensorMap, context) as\n              boolean;\n\n          return [tfOps.denseBincount(x, weights, size, binaryOutput)];\n        }\n        default:\n          throw TypeError(`Node type ${node.op} is not implemented`);\n      }\n    };\n\nexport const CATEGORY = 'reduction';\n"]}