UNPKG

@tensorflow/tfjs-converter

Version:

TensorFlow model converter for JavaScript

61 lines 11.7 kB
/**
 * @license
 * Copyright 2018 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
// tslint:disable-next-line: no-imports-from-dist
import * as tfOps from '@tensorflow/tfjs-core/dist/ops/ops_for_converter';
import { getParamValue } from './utils';

/**
 * Executes a graph node belonging to the 'matrices' op category.
 *
 * Supported ops:
 *   - 'BatchMatMul', 'BatchMatMulV2', 'MatMul' -> tfOps.matMul
 *   - 'Einsum'                                 -> tfOps.einsum
 *   - 'Transpose'                              -> tfOps.transpose
 *   - '_FusedMatMul'                           -> tfOps.fused.matMul
 *
 * @param node The graph node to execute; `node.op` selects the operation.
 * @param tensorMap Tensors available to the node; forwarded to getParamValue.
 * @param context The execution context; forwarded to getParamValue.
 * @returns A single-element array holding the op's output tensor.
 * @throws {Error} For '_FusedMatMul' whose extra op is 'biasadd' when numArgs
 *     is not 2 (with a 'prelu' activation) or not 1 (any other activation).
 * @throws {TypeError} When `node.op` is not one of the supported ops.
 */
export const executeOp = (node, tensorMap, context) => {
  // Every lookup below resolves a named parameter of this same node.
  const param = (name) => getParamValue(name, node, tensorMap, context);

  switch (node.op) {
    case 'BatchMatMul':
    case 'BatchMatMulV2':
    case 'MatMul':
      return [tfOps.matMul(
          param('a'), param('b'), param('transposeA'), param('transposeB'))];

    case 'Einsum':
      return [tfOps.einsum(param('equation'), ...param('tensors'))];

    case 'Transpose':
      return [tfOps.transpose(param('x'), param('perm'))];

    case '_FusedMatMul': {
      // Braces scope these lexical declarations to this clause only
      // (ESLint no-case-declarations).
      const [extraOp, activationFunc] = param('fusedOps');
      // NOTE(review): compared against lower-case names; presumably
      // getParamValue normalizes fusedOps to lower case — confirm in utils.
      const isBiasAdd = extraOp === 'biasadd';
      const isPrelu = activationFunc === 'prelu';
      const numArgs = param('numArgs');
      const leakyreluAlpha = param('leakyreluAlpha');

      if (isBiasAdd) {
        if (isPrelu && numArgs !== 2) {
          throw new Error(
              'Fused MatMul with BiasAdd and Prelu must have two ' +
              'extra arguments: bias and alpha.');
        }
        if (!isPrelu && numArgs !== 1) {
          throw new Error(
              'Fused MatMul with BiasAdd must have one extra argument: bias.');
        }
      }

      // NOTE(review): when extraOp is not 'biasadd', numArgs is not checked
      // and args[0] is still passed as `bias` — confirm the converter never
      // emits such a node.
      const [biasArg, preluArg] = param('args');
      return [tfOps.fused.matMul({
        a: param('a'),
        b: param('b'),
        transposeA: param('transposeA'),
        transposeB: param('transposeB'),
        bias: biasArg,
        activation: activationFunc,
        preluActivationWeights: preluArg,
        leakyreluAlpha,
      })];
    }

    default:
      throw new TypeError(`Node type ${node.op} is not implemented`);
  }
};

export const CATEGORY = 'matrices';
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoibWF0cmljZXNfZXhlY3V0b3IuanMiLCJzb3VyY2VSb290IjoiIiwic291cmNlcyI6WyIuLi8uLi8uLi8uLi8uLi8uLi8uLi90ZmpzLWNvbnZlcnRlci9zcmMvb3BlcmF0aW9ucy9leGVjdXRvcnMvbWF0cmljZXNfZXhlY3V0b3IudHMiXSwibmFtZXMiOltdLCJtYXBwaW5ncyI6IkFBQUE7Ozs7Ozs7Ozs7Ozs7OztHQWVHO0FBR0gsaURBQWlEO0FBQ2pELE9BQU8sS0FBSyxLQUFLLE1BQU0sa0RBQWtELENBQUM7QUFNMUUsT0FBTyxFQUFDLGFBQWEsRUFBQyxNQUFNLFNBQVMsQ0FBQztBQUV0QyxNQUFNLENBQUMsTUFBTSxTQUFTLEdBQ2xCLENBQUMsSUFBVSxFQUFFLFNBQTBCLEVBQ3RDLE9BQXlCLEVBQVksRUFBRTtJQUN0QyxRQUFRLElBQUksQ0FBQyxFQUFFLEVBQUU7UUFDZixLQUFLLGFBQWEsQ0FBQztRQUNuQixLQUFLLGVBQWUsQ0FBQztRQUNyQixLQUFLLFFBQVE7WUFDWCxPQUFPLENBQUMsS0FBSyxDQUFDLE1BQU0sQ0FDaEIsYUFBYSxDQUFDLEdBQUcsRUFBRSxJQUFJLEVBQUUsU0FBUyxFQUFFLE9BQU8sQ0FBYSxFQUN4RCxhQUFhLENBQUMsR0FBRyxFQUFFLElBQUksRUFBRSxTQUFTLEVBQUUsT0FBTyxDQUFhLEVBQ3hELGFBQWEsQ0FBQyxZQUFZLEVBQUUsSUFBSSxFQUFFLFNBQVMsRUFBRSxPQUFPLENBQVksRUFDaEUsYUFBYSxDQUFDLFlBQVksRUFBRSxJQUFJLEVBQUUsU0FBUyxFQUFFLE9BQU8sQ0FDekMsQ0FBQyxDQUFDLENBQUM7UUFFcEIsS0FBSyxRQUFRO1lBQ1gsT0FBTyxDQUFDLEtBQUssQ0FBQyxNQUFNLENBQ2hCLGFBQWEsQ0FBQyxVQUFVLEVBQUUsSUFBSSxFQUFFLFNBQVMsRUFBRSxPQUFPLENBQVcsRUFDN0QsR0FBRyxhQUFhLENBQUMsU0FBUyxFQUFFLElBQUksRUFBRSxTQUFTLEVBQUUsT0FBTyxDQUN4QyxDQUFDLENBQUMsQ0FBQztRQUVyQixLQUFLLFdBQVc7WUFDZCxPQUFPLENBQUMsS0FBSyxDQUFDLFNBQVMsQ0FDbkIsYUFBYSxDQU
FDLEdBQUcsRUFBRSxJQUFJLEVBQUUsU0FBUyxFQUFFLE9BQU8sQ0FBVyxFQUN0RCxhQUFhLENBQUMsTUFBTSxFQUFFLElBQUksRUFBRSxTQUFTLEVBQUUsT0FBTyxDQUFhLENBQUMsQ0FBQyxDQUFDO1FBRXBFLEtBQUssY0FBYztZQUNqQixNQUFNLENBQUMsT0FBTyxFQUFFLGNBQWMsQ0FBQyxHQUMxQixhQUFhLENBQUMsVUFBVSxFQUFFLElBQUksRUFBRSxTQUFTLEVBQUUsT0FBTyxDQUFjLENBQUM7WUFFdEUsTUFBTSxTQUFTLEdBQUcsT0FBTyxLQUFLLFNBQVMsQ0FBQztZQUN4QyxNQUFNLE9BQU8sR0FBRyxjQUFjLEtBQUssT0FBTyxDQUFDO1lBRTNDLE1BQU0sT0FBTyxHQUNSLGFBQWEsQ0FBQyxTQUFTLEVBQUUsSUFBSSxFQUFFLFNBQVMsRUFBRSxPQUFPLENBQVksQ0FBQztZQUNuRSxNQUFNLGNBQWMsR0FDaEIsYUFBYSxDQUFDLGdCQUFnQixFQUFFLElBQUksRUFBRSxTQUFTLEVBQUUsT0FBTyxDQUNsRCxDQUFDO1lBRVgsSUFBSSxTQUFTLEVBQUU7Z0JBQ2IsSUFBSSxPQUFPLElBQUksT0FBTyxLQUFLLENBQUMsRUFBRTtvQkFDNUIsTUFBTSxJQUFJLEtBQUssQ0FDWCxvREFBb0Q7d0JBQ3BELGtDQUFrQyxDQUFDLENBQUM7aUJBQ3pDO2dCQUNELElBQUksQ0FBQyxPQUFPLElBQUksT0FBTyxLQUFLLENBQUMsRUFBRTtvQkFDN0IsTUFBTSxJQUFJLEtBQUssQ0FDWCwrREFBK0QsQ0FBQyxDQUFDO2lCQUN0RTthQUNGO1lBQ0QsTUFBTSxDQUFDLE9BQU8sRUFBRSxRQUFRLENBQUMsR0FDckIsYUFBYSxDQUFDLE1BQU0sRUFBRSxJQUFJLEVBQUUsU0FBUyxFQUFFLE9BQU8sQ0FBYSxDQUFDO1lBQ2hFLE9BQU8sQ0FBQyxLQUFLLENBQUMsS0FBSyxDQUFDLE1BQU0sQ0FBQztvQkFDekIsQ0FBQyxFQUFFLGFBQWEsQ0FBQyxHQUFHLEVBQUUsSUFBSSxFQUFFLFNBQVMsRUFBRSxPQUFPLENBQWE7b0JBQzNELENBQUMsRUFBRSxhQUFhLENBQUMsR0FBRyxFQUFFLElBQUksRUFBRSxTQUFTLEVBQUUsT0FBTyxDQUFhO29CQUMzRCxVQUFVLEVBQUUsYUFBYSxDQUFDLFlBQVksRUFBRSxJQUFJLEVBQUUsU0FBUyxFQUFFLE9BQU8sQ0FDckQ7b0JBQ1gsVUFBVSxFQUFFLGFBQWEsQ0FBQyxZQUFZLEVBQUUsSUFBSSxFQUFFLFNBQVMsRUFBRSxPQUFPLENBQ3JEO29CQUNYLElBQUksRUFBRSxPQUFPO29CQUNiLFVBQVUsRUFBRSxjQUF3QztvQkFDcEQsc0JBQXNCLEVBQUUsUUFBUTtvQkFDaEMsY0FBYztpQkFDZixDQUFDLENBQUMsQ0FBQztRQUVOO1lBQ0UsTUFBTSxTQUFTLENBQUMsYUFBYSxJQUFJLENBQUMsRUFBRSxxQkFBcUIsQ0FBQyxDQUFDO0tBQzlEO0FBQ0gsQ0FBQyxDQUFDO0FBRU4sTUFBTSxDQUFDLE1BQU0sUUFBUSxHQUFHLFVBQVUsQ0FBQyIsInNvdXJjZXNDb250ZW50IjpbIi8qKlxuICogQGxpY2Vuc2VcbiAqIENvcHlyaWdodCAyMDE4IEdvb2dsZSBMTEMuIEFsbCBSaWdodHMgUmVzZXJ2ZWQuXG4gKiBMaWNlbnNlZCB1bmRlciB0aGUgQXBhY2hlIExpY2Vuc2UsIFZlcnNpb24gMi4wICh0aGUgXCJMaWNlbnNlXCIpO1xuICogeW91IG1heSBub3
QgdXNlIHRoaXMgZmlsZSBleGNlcHQgaW4gY29tcGxpYW5jZSB3aXRoIHRoZSBMaWNlbnNlLlxuICogWW91IG1heSBvYnRhaW4gYSBjb3B5IG9mIHRoZSBMaWNlbnNlIGF0XG4gKlxuICogaHR0cDovL3d3dy5hcGFjaGUub3JnL2xpY2Vuc2VzL0xJQ0VOU0UtMi4wXG4gKlxuICogVW5sZXNzIHJlcXVpcmVkIGJ5IGFwcGxpY2FibGUgbGF3IG9yIGFncmVlZCB0byBpbiB3cml0aW5nLCBzb2Z0d2FyZVxuICogZGlzdHJpYnV0ZWQgdW5kZXIgdGhlIExpY2Vuc2UgaXMgZGlzdHJpYnV0ZWQgb24gYW4gXCJBUyBJU1wiIEJBU0lTLFxuICogV0lUSE9VVCBXQVJSQU5USUVTIE9SIENPTkRJVElPTlMgT0YgQU5ZIEtJTkQsIGVpdGhlciBleHByZXNzIG9yIGltcGxpZWQuXG4gKiBTZWUgdGhlIExpY2Vuc2UgZm9yIHRoZSBzcGVjaWZpYyBsYW5ndWFnZSBnb3Zlcm5pbmcgcGVybWlzc2lvbnMgYW5kXG4gKiBsaW1pdGF0aW9ucyB1bmRlciB0aGUgTGljZW5zZS5cbiAqID09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09XG4gKi9cblxuaW1wb3J0IHtUZW5zb3IsIFRlbnNvcjJEfSBmcm9tICdAdGVuc29yZmxvdy90ZmpzLWNvcmUnO1xuLy8gdHNsaW50OmRpc2FibGUtbmV4dC1saW5lOiBuby1pbXBvcnRzLWZyb20tZGlzdFxuaW1wb3J0ICogYXMgdGZPcHMgZnJvbSAnQHRlbnNvcmZsb3cvdGZqcy1jb3JlL2Rpc3Qvb3BzL29wc19mb3JfY29udmVydGVyJztcblxuaW1wb3J0IHtOYW1lZFRlbnNvcnNNYXB9IGZyb20gJy4uLy4uL2RhdGEvdHlwZXMnO1xuaW1wb3J0IHtFeGVjdXRpb25Db250ZXh0fSBmcm9tICcuLi8uLi9leGVjdXRvci9leGVjdXRpb25fY29udGV4dCc7XG5pbXBvcnQge0ludGVybmFsT3BFeGVjdXRvciwgTm9kZX0gZnJvbSAnLi4vdHlwZXMnO1xuXG5pbXBvcnQge2dldFBhcmFtVmFsdWV9IGZyb20gJy4vdXRpbHMnO1xuXG5leHBvcnQgY29uc3QgZXhlY3V0ZU9wOiBJbnRlcm5hbE9wRXhlY3V0b3IgPVxuICAgIChub2RlOiBOb2RlLCB0ZW5zb3JNYXA6IE5hbWVkVGVuc29yc01hcCxcbiAgICAgY29udGV4dDogRXhlY3V0aW9uQ29udGV4dCk6IFRlbnNvcltdID0+IHtcbiAgICAgIHN3aXRjaCAobm9kZS5vcCkge1xuICAgICAgICBjYXNlICdCYXRjaE1hdE11bCc6XG4gICAgICAgIGNhc2UgJ0JhdGNoTWF0TXVsVjInOlxuICAgICAgICBjYXNlICdNYXRNdWwnOlxuICAgICAgICAgIHJldHVybiBbdGZPcHMubWF0TXVsKFxuICAgICAgICAgICAgICBnZXRQYXJhbVZhbHVlKCdhJywgbm9kZSwgdGVuc29yTWFwLCBjb250ZXh0KSBhcyBUZW5zb3IyRCxcbiAgICAgICAgICAgICAgZ2V0UGFyYW1WYWx1ZSgnYicsIG5vZGUsIHRlbnNvck1hcCwgY29udGV4dCkgYXMgVGVuc29yMkQsXG4gICAgICAgICAgICAgIGdldFBhcmFtVmFsdWUoJ3RyYW5zcG9zZUEnLCBub2RlLCB0ZW5zb3JNYXAsIGNvbnRleHQpIGFzIGJvb2xlYW4sXG4gICAgICAgICAgICAgIGdldF
BhcmFtVmFsdWUoJ3RyYW5zcG9zZUInLCBub2RlLCB0ZW5zb3JNYXAsIGNvbnRleHQpIGFzXG4gICAgICAgICAgICAgICAgICBib29sZWFuKV07XG5cbiAgICAgICAgY2FzZSAnRWluc3VtJzpcbiAgICAgICAgICByZXR1cm4gW3RmT3BzLmVpbnN1bShcbiAgICAgICAgICAgICAgZ2V0UGFyYW1WYWx1ZSgnZXF1YXRpb24nLCBub2RlLCB0ZW5zb3JNYXAsIGNvbnRleHQpIGFzIHN0cmluZyxcbiAgICAgICAgICAgICAgLi4uZ2V0UGFyYW1WYWx1ZSgndGVuc29ycycsIG5vZGUsIHRlbnNvck1hcCwgY29udGV4dCkgYXNcbiAgICAgICAgICAgICAgICAgIFRlbnNvcltdKV07XG5cbiAgICAgICAgY2FzZSAnVHJhbnNwb3NlJzpcbiAgICAgICAgICByZXR1cm4gW3RmT3BzLnRyYW5zcG9zZShcbiAgICAgICAgICAgICAgZ2V0UGFyYW1WYWx1ZSgneCcsIG5vZGUsIHRlbnNvck1hcCwgY29udGV4dCkgYXMgVGVuc29yLFxuICAgICAgICAgICAgICBnZXRQYXJhbVZhbHVlKCdwZXJtJywgbm9kZSwgdGVuc29yTWFwLCBjb250ZXh0KSBhcyBudW1iZXJbXSldO1xuXG4gICAgICAgIGNhc2UgJ19GdXNlZE1hdE11bCc6XG4gICAgICAgICAgY29uc3QgW2V4dHJhT3AsIGFjdGl2YXRpb25GdW5jXSA9XG4gICAgICAgICAgICAgIChnZXRQYXJhbVZhbHVlKCdmdXNlZE9wcycsIG5vZGUsIHRlbnNvck1hcCwgY29udGV4dCkgYXMgc3RyaW5nW10pO1xuXG4gICAgICAgICAgY29uc3QgaXNCaWFzQWRkID0gZXh0cmFPcCA9PT0gJ2JpYXNhZGQnO1xuICAgICAgICAgIGNvbnN0IGlzUHJlbHUgPSBhY3RpdmF0aW9uRnVuYyA9PT0gJ3ByZWx1JztcblxuICAgICAgICAgIGNvbnN0IG51bUFyZ3MgPVxuICAgICAgICAgICAgICAoZ2V0UGFyYW1WYWx1ZSgnbnVtQXJncycsIG5vZGUsIHRlbnNvck1hcCwgY29udGV4dCkgYXMgbnVtYmVyKTtcbiAgICAgICAgICBjb25zdCBsZWFreXJlbHVBbHBoYSA9XG4gICAgICAgICAgICAgIGdldFBhcmFtVmFsdWUoJ2xlYWt5cmVsdUFscGhhJywgbm9kZSwgdGVuc29yTWFwLCBjb250ZXh0KSBhc1xuICAgICAgICAgICAgICBudW1iZXI7XG5cbiAgICAgICAgICBpZiAoaXNCaWFzQWRkKSB7XG4gICAgICAgICAgICBpZiAoaXNQcmVsdSAmJiBudW1BcmdzICE9PSAyKSB7XG4gICAgICAgICAgICAgIHRocm93IG5ldyBFcnJvcihcbiAgICAgICAgICAgICAgICAgICdGdXNlZCBNYXRNdWwgd2l0aCBCaWFzQWRkIGFuZCBQcmVsdSBtdXN0IGhhdmUgdHdvICcgK1xuICAgICAgICAgICAgICAgICAgJ2V4dHJhIGFyZ3VtZW50czogYmlhcyBhbmQgYWxwaGEuJyk7XG4gICAgICAgICAgICB9XG4gICAgICAgICAgICBpZiAoIWlzUHJlbHUgJiYgbnVtQXJncyAhPT0gMSkge1xuICAgICAgICAgICAgICB0aHJvdyBuZXcgRXJyb3IoXG4gICAgICAgICAgICAgICAgICAnRnVzZWQgTWF0TXVsIHdpdGggQmlhc0FkZCBtdXN0IGhhdmUgb25lIGV4dHJhIGFyZ3VtZW50OiBiaWFzLicpO1xuICAgICAgICAgICAgfVxuICAgICAgICAgIH1cbiAgICAgICAgIC
Bjb25zdCBbYmlhc0FyZywgcHJlbHVBcmddID1cbiAgICAgICAgICAgICAgZ2V0UGFyYW1WYWx1ZSgnYXJncycsIG5vZGUsIHRlbnNvck1hcCwgY29udGV4dCkgYXMgVGVuc29yW107XG4gICAgICAgICAgcmV0dXJuIFt0Zk9wcy5mdXNlZC5tYXRNdWwoe1xuICAgICAgICAgICAgYTogZ2V0UGFyYW1WYWx1ZSgnYScsIG5vZGUsIHRlbnNvck1hcCwgY29udGV4dCkgYXMgVGVuc29yMkQsXG4gICAgICAgICAgICBiOiBnZXRQYXJhbVZhbHVlKCdiJywgbm9kZSwgdGVuc29yTWFwLCBjb250ZXh0KSBhcyBUZW5zb3IyRCxcbiAgICAgICAgICAgIHRyYW5zcG9zZUE6IGdldFBhcmFtVmFsdWUoJ3RyYW5zcG9zZUEnLCBub2RlLCB0ZW5zb3JNYXAsIGNvbnRleHQpIGFzXG4gICAgICAgICAgICAgICAgYm9vbGVhbixcbiAgICAgICAgICAgIHRyYW5zcG9zZUI6IGdldFBhcmFtVmFsdWUoJ3RyYW5zcG9zZUInLCBub2RlLCB0ZW5zb3JNYXAsIGNvbnRleHQpIGFzXG4gICAgICAgICAgICAgICAgYm9vbGVhbixcbiAgICAgICAgICAgIGJpYXM6IGJpYXNBcmcsXG4gICAgICAgICAgICBhY3RpdmF0aW9uOiBhY3RpdmF0aW9uRnVuYyBhcyB0Zk9wcy5mdXNlZC5BY3RpdmF0aW9uLFxuICAgICAgICAgICAgcHJlbHVBY3RpdmF0aW9uV2VpZ2h0czogcHJlbHVBcmcsXG4gICAgICAgICAgICBsZWFreXJlbHVBbHBoYVxuICAgICAgICAgIH0pXTtcblxuICAgICAgICBkZWZhdWx0OlxuICAgICAgICAgIHRocm93IFR5cGVFcnJvcihgTm9kZSB0eXBlICR7bm9kZS5vcH0gaXMgbm90IGltcGxlbWVudGVkYCk7XG4gICAgICB9XG4gICAgfTtcblxuZXhwb3J0IGNvbnN0IENBVEVHT1JZID0gJ21hdHJpY2VzJztcbiJdfQ==