UNPKG

@tensorflow/tfjs-converter

Version:

Tensorflow model converter for javascript

63 lines 12.7 kB
/**
 * @license
 * Copyright 2018 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
// tslint:disable-next-line: no-imports-from-dist
import * as tfOps from '@tensorflow/tfjs-core/dist/ops/ops_for_converter';
import { getParamValue } from './utils';

/**
 * Executes a single graph node belonging to the 'matrices' category.
 *
 * The node's `op` name selects which function of `ops` is invoked; every
 * operand and attribute is resolved through `getParamValue` using the node,
 * the tensor map and the execution context.
 *
 * @param node Graph node to execute; `node.op` picks the operation.
 * @param tensorMap Passed through to `getParamValue` for parameter lookup.
 * @param context Passed through to `getParamValue` for parameter lookup.
 * @param ops Object providing the op implementations (`matMul`, `einsum`,
 *     `transpose`, `fused.matMul`, `linalg.bandPart`). Defaults to the
 *     converter ops imported from tfjs-core.
 * @returns A one-element array holding the result of the selected op.
 * @throws {Error} For `_FusedMatMul` when the first fused op is 'biasadd' and
 *     `numArgs` is not 2 (activation 'prelu') or not 1 (any other activation).
 * @throws {TypeError} When `node.op` is not one of the handled op names.
 */
export const executeOp = (node, tensorMap, context, ops = tfOps) => {
  // Resolves one named parameter of the current node.
  const param = (name) => getParamValue(name, node, tensorMap, context);

  switch (node.op) {
    case 'BatchMatMul':
    case 'BatchMatMulV2':
    case 'MatMul': {
      return [
        ops.matMul(
          param('a'),
          param('b'),
          param('transposeA'),
          param('transposeB'),
        ),
      ];
    }

    case 'Einsum': {
      return [ops.einsum(param('equation'), ...param('tensors'))];
    }

    case 'Transpose': {
      return [ops.transpose(param('x'), param('perm'))];
    }

    case '_FusedMatMul': {
      // fusedOps is read as [extra op, activation function name].
      const [fusedOp, activation] = param('fusedOps');
      const numArgs = param('numArgs');
      const leakyreluAlpha = param('leakyreluAlpha');

      // The extra-argument count is only validated when a bias add is fused.
      if (fusedOp === 'biasadd') {
        if (activation === 'prelu') {
          if (numArgs !== 2) {
            throw new Error(
              'Fused MatMul with BiasAdd and Prelu must have two ' +
                'extra arguments: bias and alpha.',
            );
          }
        } else if (numArgs !== 1) {
          throw new Error(
            'Fused MatMul with BiasAdd must have one extra argument: bias.',
          );
        }
      }

      // args is read as [bias, prelu weights]; missing entries are undefined.
      const [bias, preluActivationWeights] = param('args');
      return [
        ops.fused.matMul({
          a: param('a'),
          b: param('b'),
          transposeA: param('transposeA'),
          transposeB: param('transposeB'),
          bias,
          activation,
          preluActivationWeights,
          leakyreluAlpha,
        }),
      ];
    }

    case 'MatrixBandPart': {
      return [
        ops.linalg.bandPart(param('a'), param('numLower'), param('numUpper')),
      ];
    }

    default:
      throw new TypeError(`Node type ${node.op} is not implemented`);
  }
};

export const CATEGORY = 'matrices';
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoibWF0cmljZXNfZXhlY3V0b3IuanMiLCJzb3VyY2VSb290IjoiIiwic291cmNlcyI6WyIuLi8uLi8uLi8uLi8uLi8uLi8uLi90ZmpzLWNvbnZlcnRlci9zcmMvb3BlcmF0aW9ucy9leGVjdXRvcnMvbWF0cmljZXNfZXhlY3V0b3IudHMiXSwibmFtZXMiOltdLCJtYXBwaW5ncyI6IkFBQUE7Ozs7Ozs7Ozs7Ozs7OztHQWVHO0FBR0gsaURBQWlEO0FBQ2pELE9BQU8sS0FBSyxLQUFLLE1BQU0sa0RBQWtELENBQUM7QUFNMUUsT0FBTyxFQUFDLGFBQWEsRUFBQyxNQUFNLFNBQVMsQ0FBQztBQUV0QyxNQUFNLENBQUMsTUFBTSxTQUFTLEdBQ2xCLENBQUMsSUFBVSxFQUFFLFNBQTBCLEVBQUUsT0FBeUIsRUFDakUsR0FBRyxHQUFHLEtBQUssRUFBWSxFQUFFO0lBQ3hCLFFBQVEsSUFBSSxDQUFDLEVBQUUsRUFBRTtRQUNmLEtBQUssYUFBYSxDQUFDO1FBQ25CLEtBQUssZUFBZSxDQUFDO1FBQ3JCLEtBQUssUUFBUTtZQUNYLE9BQU8sQ0FBQyxHQUFHLENBQUMsTUFBTSxDQUNkLGFBQWEsQ0FBQyxHQUFHLEVBQUUsSUFBSSxFQUFFLFNBQVMsRUFBRSxPQUFPLENBQWEsRUFDeEQsYUFBYSxDQUFDLEdBQUcsRUFBRSxJQUFJLEVBQUUsU0FBUyxFQUFFLE9BQU8sQ0FBYSxFQUN4RCxhQUFhLENBQUMsWUFBWSxFQUFFLElBQUksRUFBRSxTQUFTLEVBQUUsT0FBTyxDQUFZLEVBQ2hFLGFBQWEsQ0FBQyxZQUFZLEVBQUUsSUFBSSxFQUFFLFNBQVMsRUFBRSxPQUFPLENBQ3pDLENBQUMsQ0FBQyxDQUFDO1FBRXBCLEtBQUssUUFBUTtZQUNYLE9BQU8sQ0FBQyxHQUFHLENBQUMsTUFBTSxDQUNkLGFBQWEsQ0F
BQyxVQUFVLEVBQUUsSUFBSSxFQUFFLFNBQVMsRUFBRSxPQUFPLENBQVcsRUFDN0QsR0FBRyxhQUFhLENBQUMsU0FBUyxFQUFFLElBQUksRUFBRSxTQUFTLEVBQUUsT0FBTyxDQUN4QyxDQUFDLENBQUMsQ0FBQztRQUVyQixLQUFLLFdBQVc7WUFDZCxPQUFPLENBQUMsR0FBRyxDQUFDLFNBQVMsQ0FDakIsYUFBYSxDQUFDLEdBQUcsRUFBRSxJQUFJLEVBQUUsU0FBUyxFQUFFLE9BQU8sQ0FBVyxFQUN0RCxhQUFhLENBQUMsTUFBTSxFQUFFLElBQUksRUFBRSxTQUFTLEVBQUUsT0FBTyxDQUFhLENBQUMsQ0FBQyxDQUFDO1FBRXBFLEtBQUssY0FBYztZQUNqQixNQUFNLENBQUMsT0FBTyxFQUFFLGNBQWMsQ0FBQyxHQUMxQixhQUFhLENBQUMsVUFBVSxFQUFFLElBQUksRUFBRSxTQUFTLEVBQUUsT0FBTyxDQUFjLENBQUM7WUFFdEUsTUFBTSxTQUFTLEdBQUcsT0FBTyxLQUFLLFNBQVMsQ0FBQztZQUN4QyxNQUFNLE9BQU8sR0FBRyxjQUFjLEtBQUssT0FBTyxDQUFDO1lBRTNDLE1BQU0sT0FBTyxHQUNSLGFBQWEsQ0FBQyxTQUFTLEVBQUUsSUFBSSxFQUFFLFNBQVMsRUFBRSxPQUFPLENBQVksQ0FBQztZQUNuRSxNQUFNLGNBQWMsR0FDaEIsYUFBYSxDQUFDLGdCQUFnQixFQUFFLElBQUksRUFBRSxTQUFTLEVBQUUsT0FBTyxDQUNsRCxDQUFDO1lBRVgsSUFBSSxTQUFTLEVBQUU7Z0JBQ2IsSUFBSSxPQUFPLElBQUksT0FBTyxLQUFLLENBQUMsRUFBRTtvQkFDNUIsTUFBTSxJQUFJLEtBQUssQ0FDWCxvREFBb0Q7d0JBQ3BELGtDQUFrQyxDQUFDLENBQUM7aUJBQ3pDO2dCQUNELElBQUksQ0FBQyxPQUFPLElBQUksT0FBTyxLQUFLLENBQUMsRUFBRTtvQkFDN0IsTUFBTSxJQUFJLEtBQUssQ0FDWCwrREFBK0QsQ0FBQyxDQUFDO2lCQUN0RTthQUNGO1lBQ0QsTUFBTSxDQUFDLE9BQU8sRUFBRSxRQUFRLENBQUMsR0FDckIsYUFBYSxDQUFDLE1BQU0sRUFBRSxJQUFJLEVBQUUsU0FBUyxFQUFFLE9BQU8sQ0FBYSxDQUFDO1lBQ2hFLE9BQU8sQ0FBQyxHQUFHLENBQUMsS0FBSyxDQUFDLE1BQU0sQ0FBQztvQkFDdkIsQ0FBQyxFQUFFLGFBQWEsQ0FBQyxHQUFHLEVBQUUsSUFBSSxFQUFFLFNBQVMsRUFBRSxPQUFPLENBQWE7b0JBQzNELENBQUMsRUFBRSxhQUFhLENBQUMsR0FBRyxFQUFFLElBQUksRUFBRSxTQUFTLEVBQUUsT0FBTyxDQUFhO29CQUMzRCxVQUFVLEVBQUUsYUFBYSxDQUFDLFlBQVksRUFBRSxJQUFJLEVBQUUsU0FBUyxFQUFFLE9BQU8sQ0FDckQ7b0JBQ1gsVUFBVSxFQUFFLGFBQWEsQ0FBQyxZQUFZLEVBQUUsSUFBSSxFQUFFLFNBQVMsRUFBRSxPQUFPLENBQ3JEO29CQUNYLElBQUksRUFBRSxPQUFPO29CQUNiLFVBQVUsRUFBRSxjQUF3QztvQkFDcEQsc0JBQXNCLEVBQUUsUUFBUTtvQkFDaEMsY0FBYztpQkFDZixDQUFDLENBQUMsQ0FBQztRQUVOLEtBQUssZ0JBQWdCO1lBQ25CLE9BQU8sQ0FBQyxHQUFHLENBQUMsTUFBTSxDQUFDLFFBQVEsQ0FDdkIsYUFBYSxDQUFDLEdBQUcsRUFBRSxJQUFJLEVBQUUsU0FBUyxFQUFFLE9BQU8sQ0FBYSxFQUN
4RCxhQUFhLENBQUMsVUFBVSxFQUFFLElBQUksRUFBRSxTQUFTLEVBQUUsT0FBTyxDQUFXLEVBQzdELGFBQWEsQ0FBQyxVQUFVLEVBQUUsSUFBSSxFQUFFLFNBQVMsRUFBRSxPQUFPLENBQVcsQ0FBQyxDQUFDLENBQUM7UUFFdEU7WUFDRSxNQUFNLFNBQVMsQ0FBQyxhQUFhLElBQUksQ0FBQyxFQUFFLHFCQUFxQixDQUFDLENBQUM7S0FDOUQ7QUFDSCxDQUFDLENBQUM7QUFFTixNQUFNLENBQUMsTUFBTSxRQUFRLEdBQUcsVUFBVSxDQUFDIiwic291cmNlc0NvbnRlbnQiOlsiLyoqXG4gKiBAbGljZW5zZVxuICogQ29weXJpZ2h0IDIwMTggR29vZ2xlIExMQy4gQWxsIFJpZ2h0cyBSZXNlcnZlZC5cbiAqIExpY2Vuc2VkIHVuZGVyIHRoZSBBcGFjaGUgTGljZW5zZSwgVmVyc2lvbiAyLjAgKHRoZSBcIkxpY2Vuc2VcIik7XG4gKiB5b3UgbWF5IG5vdCB1c2UgdGhpcyBmaWxlIGV4Y2VwdCBpbiBjb21wbGlhbmNlIHdpdGggdGhlIExpY2Vuc2UuXG4gKiBZb3UgbWF5IG9idGFpbiBhIGNvcHkgb2YgdGhlIExpY2Vuc2UgYXRcbiAqXG4gKiBodHRwOi8vd3d3LmFwYWNoZS5vcmcvbGljZW5zZXMvTElDRU5TRS0yLjBcbiAqXG4gKiBVbmxlc3MgcmVxdWlyZWQgYnkgYXBwbGljYWJsZSBsYXcgb3IgYWdyZWVkIHRvIGluIHdyaXRpbmcsIHNvZnR3YXJlXG4gKiBkaXN0cmlidXRlZCB1bmRlciB0aGUgTGljZW5zZSBpcyBkaXN0cmlidXRlZCBvbiBhbiBcIkFTIElTXCIgQkFTSVMsXG4gKiBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC5cbiAqIFNlZSB0aGUgTGljZW5zZSBmb3IgdGhlIHNwZWNpZmljIGxhbmd1YWdlIGdvdmVybmluZyBwZXJtaXNzaW9ucyBhbmRcbiAqIGxpbWl0YXRpb25zIHVuZGVyIHRoZSBMaWNlbnNlLlxuICogPT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT1cbiAqL1xuXG5pbXBvcnQge1NjYWxhciwgVGVuc29yLCBUZW5zb3IyRH0gZnJvbSAnQHRlbnNvcmZsb3cvdGZqcy1jb3JlJztcbi8vIHRzbGludDpkaXNhYmxlLW5leHQtbGluZTogbm8taW1wb3J0cy1mcm9tLWRpc3RcbmltcG9ydCAqIGFzIHRmT3BzIGZyb20gJ0B0ZW5zb3JmbG93L3RmanMtY29yZS9kaXN0L29wcy9vcHNfZm9yX2NvbnZlcnRlcic7XG5cbmltcG9ydCB7TmFtZWRUZW5zb3JzTWFwfSBmcm9tICcuLi8uLi9kYXRhL3R5cGVzJztcbmltcG9ydCB7RXhlY3V0aW9uQ29udGV4dH0gZnJvbSAnLi4vLi4vZXhlY3V0b3IvZXhlY3V0aW9uX2NvbnRleHQnO1xuaW1wb3J0IHtJbnRlcm5hbE9wRXhlY3V0b3IsIE5vZGV9IGZyb20gJy4uL3R5cGVzJztcblxuaW1wb3J0IHtnZXRQYXJhbVZhbHVlfSBmcm9tICcuL3V0aWxzJztcblxuZXhwb3J0IGNvbnN0IGV4ZWN1dGVPcDogSW50ZXJuYWxPcEV4ZWN1dG9yID1cbiAgICAobm9kZTogTm9kZSwgdGVuc29yTWFwOiBOYW1lZFRlbnNvcnNNYXAsIGNvbnRleHQ
6IEV4ZWN1dGlvbkNvbnRleHQsXG4gICAgIG9wcyA9IHRmT3BzKTogVGVuc29yW10gPT4ge1xuICAgICAgc3dpdGNoIChub2RlLm9wKSB7XG4gICAgICAgIGNhc2UgJ0JhdGNoTWF0TXVsJzpcbiAgICAgICAgY2FzZSAnQmF0Y2hNYXRNdWxWMic6XG4gICAgICAgIGNhc2UgJ01hdE11bCc6XG4gICAgICAgICAgcmV0dXJuIFtvcHMubWF0TXVsKFxuICAgICAgICAgICAgICBnZXRQYXJhbVZhbHVlKCdhJywgbm9kZSwgdGVuc29yTWFwLCBjb250ZXh0KSBhcyBUZW5zb3IyRCxcbiAgICAgICAgICAgICAgZ2V0UGFyYW1WYWx1ZSgnYicsIG5vZGUsIHRlbnNvck1hcCwgY29udGV4dCkgYXMgVGVuc29yMkQsXG4gICAgICAgICAgICAgIGdldFBhcmFtVmFsdWUoJ3RyYW5zcG9zZUEnLCBub2RlLCB0ZW5zb3JNYXAsIGNvbnRleHQpIGFzIGJvb2xlYW4sXG4gICAgICAgICAgICAgIGdldFBhcmFtVmFsdWUoJ3RyYW5zcG9zZUInLCBub2RlLCB0ZW5zb3JNYXAsIGNvbnRleHQpIGFzXG4gICAgICAgICAgICAgICAgICBib29sZWFuKV07XG5cbiAgICAgICAgY2FzZSAnRWluc3VtJzpcbiAgICAgICAgICByZXR1cm4gW29wcy5laW5zdW0oXG4gICAgICAgICAgICAgIGdldFBhcmFtVmFsdWUoJ2VxdWF0aW9uJywgbm9kZSwgdGVuc29yTWFwLCBjb250ZXh0KSBhcyBzdHJpbmcsXG4gICAgICAgICAgICAgIC4uLmdldFBhcmFtVmFsdWUoJ3RlbnNvcnMnLCBub2RlLCB0ZW5zb3JNYXAsIGNvbnRleHQpIGFzXG4gICAgICAgICAgICAgICAgICBUZW5zb3JbXSldO1xuXG4gICAgICAgIGNhc2UgJ1RyYW5zcG9zZSc6XG4gICAgICAgICAgcmV0dXJuIFtvcHMudHJhbnNwb3NlKFxuICAgICAgICAgICAgICBnZXRQYXJhbVZhbHVlKCd4Jywgbm9kZSwgdGVuc29yTWFwLCBjb250ZXh0KSBhcyBUZW5zb3IsXG4gICAgICAgICAgICAgIGdldFBhcmFtVmFsdWUoJ3Blcm0nLCBub2RlLCB0ZW5zb3JNYXAsIGNvbnRleHQpIGFzIG51bWJlcltdKV07XG5cbiAgICAgICAgY2FzZSAnX0Z1c2VkTWF0TXVsJzpcbiAgICAgICAgICBjb25zdCBbZXh0cmFPcCwgYWN0aXZhdGlvbkZ1bmNdID1cbiAgICAgICAgICAgICAgKGdldFBhcmFtVmFsdWUoJ2Z1c2VkT3BzJywgbm9kZSwgdGVuc29yTWFwLCBjb250ZXh0KSBhcyBzdHJpbmdbXSk7XG5cbiAgICAgICAgICBjb25zdCBpc0JpYXNBZGQgPSBleHRyYU9wID09PSAnYmlhc2FkZCc7XG4gICAgICAgICAgY29uc3QgaXNQcmVsdSA9IGFjdGl2YXRpb25GdW5jID09PSAncHJlbHUnO1xuXG4gICAgICAgICAgY29uc3QgbnVtQXJncyA9XG4gICAgICAgICAgICAgIChnZXRQYXJhbVZhbHVlKCdudW1BcmdzJywgbm9kZSwgdGVuc29yTWFwLCBjb250ZXh0KSBhcyBudW1iZXIpO1xuICAgICAgICAgIGNvbnN0IGxlYWt5cmVsdUFscGhhID1cbiAgICAgICAgICAgICAgZ2V0UGFyYW1WYWx1ZSgnbGVha3lyZWx1QWxwaGEnLCBub2RlLCB0ZW5zb3JNYXAsIGNvbnRleHQpIGFzXG4gICAgICAgICAgICAgIG51bWJlcjtcblxuICAgICAgICAgIGlmICh
pc0JpYXNBZGQpIHtcbiAgICAgICAgICAgIGlmIChpc1ByZWx1ICYmIG51bUFyZ3MgIT09IDIpIHtcbiAgICAgICAgICAgICAgdGhyb3cgbmV3IEVycm9yKFxuICAgICAgICAgICAgICAgICAgJ0Z1c2VkIE1hdE11bCB3aXRoIEJpYXNBZGQgYW5kIFByZWx1IG11c3QgaGF2ZSB0d28gJyArXG4gICAgICAgICAgICAgICAgICAnZXh0cmEgYXJndW1lbnRzOiBiaWFzIGFuZCBhbHBoYS4nKTtcbiAgICAgICAgICAgIH1cbiAgICAgICAgICAgIGlmICghaXNQcmVsdSAmJiBudW1BcmdzICE9PSAxKSB7XG4gICAgICAgICAgICAgIHRocm93IG5ldyBFcnJvcihcbiAgICAgICAgICAgICAgICAgICdGdXNlZCBNYXRNdWwgd2l0aCBCaWFzQWRkIG11c3QgaGF2ZSBvbmUgZXh0cmEgYXJndW1lbnQ6IGJpYXMuJyk7XG4gICAgICAgICAgICB9XG4gICAgICAgICAgfVxuICAgICAgICAgIGNvbnN0IFtiaWFzQXJnLCBwcmVsdUFyZ10gPVxuICAgICAgICAgICAgICBnZXRQYXJhbVZhbHVlKCdhcmdzJywgbm9kZSwgdGVuc29yTWFwLCBjb250ZXh0KSBhcyBUZW5zb3JbXTtcbiAgICAgICAgICByZXR1cm4gW29wcy5mdXNlZC5tYXRNdWwoe1xuICAgICAgICAgICAgYTogZ2V0UGFyYW1WYWx1ZSgnYScsIG5vZGUsIHRlbnNvck1hcCwgY29udGV4dCkgYXMgVGVuc29yMkQsXG4gICAgICAgICAgICBiOiBnZXRQYXJhbVZhbHVlKCdiJywgbm9kZSwgdGVuc29yTWFwLCBjb250ZXh0KSBhcyBUZW5zb3IyRCxcbiAgICAgICAgICAgIHRyYW5zcG9zZUE6IGdldFBhcmFtVmFsdWUoJ3RyYW5zcG9zZUEnLCBub2RlLCB0ZW5zb3JNYXAsIGNvbnRleHQpIGFzXG4gICAgICAgICAgICAgICAgYm9vbGVhbixcbiAgICAgICAgICAgIHRyYW5zcG9zZUI6IGdldFBhcmFtVmFsdWUoJ3RyYW5zcG9zZUInLCBub2RlLCB0ZW5zb3JNYXAsIGNvbnRleHQpIGFzXG4gICAgICAgICAgICAgICAgYm9vbGVhbixcbiAgICAgICAgICAgIGJpYXM6IGJpYXNBcmcsXG4gICAgICAgICAgICBhY3RpdmF0aW9uOiBhY3RpdmF0aW9uRnVuYyBhcyB0Zk9wcy5mdXNlZC5BY3RpdmF0aW9uLFxuICAgICAgICAgICAgcHJlbHVBY3RpdmF0aW9uV2VpZ2h0czogcHJlbHVBcmcsXG4gICAgICAgICAgICBsZWFreXJlbHVBbHBoYVxuICAgICAgICAgIH0pXTtcblxuICAgICAgICBjYXNlICdNYXRyaXhCYW5kUGFydCc6XG4gICAgICAgICAgcmV0dXJuIFtvcHMubGluYWxnLmJhbmRQYXJ0KFxuICAgICAgICAgICAgICBnZXRQYXJhbVZhbHVlKCdhJywgbm9kZSwgdGVuc29yTWFwLCBjb250ZXh0KSBhcyBUZW5zb3IyRCxcbiAgICAgICAgICAgICAgZ2V0UGFyYW1WYWx1ZSgnbnVtTG93ZXInLCBub2RlLCB0ZW5zb3JNYXAsIGNvbnRleHQpIGFzIFNjYWxhcixcbiAgICAgICAgICAgICAgZ2V0UGFyYW1WYWx1ZSgnbnVtVXBwZXInLCBub2RlLCB0ZW5zb3JNYXAsIGNvbnRleHQpIGFzIFNjYWxhcildO1xuXG4gICAgICAgIGRlZmF1bHQ6XG4gICAgICAgICAgdGhyb3cgVHlwZUVycm9yKGBOb2RlIHR5cGUgJHtub2RlLm9
wfSBpcyBub3QgaW1wbGVtZW50ZWRgKTtcbiAgICAgIH1cbiAgICB9O1xuXG5leHBvcnQgY29uc3QgQ0FURUdPUlkgPSAnbWF0cmljZXMnO1xuIl19