@tensorflow/tfjs-converter
Version:
TensorFlow model converter for JavaScript
61 lines • 11.7 kB
JavaScript
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// tslint:disable-next-line: no-imports-from-dist
import * as tfOps from '@tensorflow/tfjs-core/dist/ops/ops_for_converter';
import { getParamValue } from './utils';
/**
 * Executes one graph node belonging to the matrices op family.
 *
 * Supported `node.op` values are 'BatchMatMul', 'BatchMatMulV2', 'MatMul',
 * 'Einsum', 'Transpose' and '_FusedMatMul'. Every input and attribute is
 * resolved through `getParamValue` and then forwarded to the matching
 * `tfOps` function.
 *
 * @param node Graph node; `node.op` selects which operation runs.
 * @param tensorMap Forwarded unchanged to `getParamValue`.
 * @param context Forwarded unchanged to `getParamValue`.
 * @returns A one-element array wrapping the value returned by the op call.
 * @throws {Error} For '_FusedMatMul' whose first fused op is 'biasadd' when
 *     `numArgs` is not 2 (with 'prelu' activation) or not 1 (otherwise).
 * @throws {TypeError} When `node.op` is none of the ops listed above.
 */
export const executeOp = (node, tensorMap, context) => {
    // Every lookup shares the same node / tensorMap / context triple.
    const readParam = (paramName) => getParamValue(paramName, node, tensorMap, context);
    const opName = node.op;

    if (opName === 'BatchMatMul' || opName === 'BatchMatMulV2' || opName === 'MatMul') {
        const lhs = readParam('a');
        const rhs = readParam('b');
        const transposeLhs = readParam('transposeA');
        const transposeRhs = readParam('transposeB');
        return [tfOps.matMul(lhs, rhs, transposeLhs, transposeRhs)];
    }

    if (opName === 'Einsum') {
        const equation = readParam('equation');
        const operands = readParam('tensors');
        return [tfOps.einsum(equation, ...operands)];
    }

    if (opName === 'Transpose') {
        const input = readParam('x');
        const permutation = readParam('perm');
        return [tfOps.transpose(input, permutation)];
    }

    if (opName === '_FusedMatMul') {
        const [firstFusedOp, activation] = readParam('fusedOps');
        const hasBiasAdd = firstFusedOp === 'biasadd';
        const usesPrelu = activation === 'prelu';
        const numArgs = readParam('numArgs');
        const leakyreluAlpha = readParam('leakyreluAlpha');

        // The extra-argument count is only validated when the first fused op
        // is 'biasadd': bias + alpha for prelu, bias alone otherwise.
        if (hasBiasAdd) {
            const expectedArgCount = usesPrelu ? 2 : 1;
            if (numArgs !== expectedArgCount) {
                throw new Error(usesPrelu ?
                    'Fused MatMul with BiasAdd and Prelu must have two ' +
                        'extra arguments: bias and alpha.' :
                    'Fused MatMul with BiasAdd must have one extra argument: bias.');
            }
        }

        const [bias, preluActivationWeights] = readParam('args');
        const a = readParam('a');
        const b = readParam('b');
        const transposeA = readParam('transposeA');
        const transposeB = readParam('transposeB');
        return [tfOps.fused.matMul({
                a,
                b,
                transposeA,
                transposeB,
                bias,
                activation,
                preluActivationWeights,
                leakyreluAlpha
            })];
    }

    throw new TypeError(`Node type ${node.op} is not implemented`);
};
// Category label exported alongside executeOp for this module's ops.
// NOTE(review): presumably used elsewhere to group executors by category — confirm against the importer.
export const CATEGORY = 'matrices';
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoibWF0cmljZXNfZXhlY3V0b3IuanMiLCJzb3VyY2VSb290IjoiIiwic291cmNlcyI6WyIuLi8uLi8uLi8uLi8uLi8uLi8uLi90ZmpzLWNvbnZlcnRlci9zcmMvb3BlcmF0aW9ucy9leGVjdXRvcnMvbWF0cmljZXNfZXhlY3V0b3IudHMiXSwibmFtZXMiOltdLCJtYXBwaW5ncyI6IkFBQUE7Ozs7Ozs7Ozs7Ozs7OztHQWVHO0FBR0gsaURBQWlEO0FBQ2pELE9BQU8sS0FBSyxLQUFLLE1BQU0sa0RBQWtELENBQUM7QUFNMUUsT0FBTyxFQUFDLGFBQWEsRUFBQyxNQUFNLFNBQVMsQ0FBQztBQUV0QyxNQUFNLENBQUMsTUFBTSxTQUFTLEdBQ2xCLENBQUMsSUFBVSxFQUFFLFNBQTBCLEVBQ3RDLE9BQXlCLEVBQVksRUFBRTtJQUN0QyxRQUFRLElBQUksQ0FBQyxFQUFFLEVBQUU7UUFDZixLQUFLLGFBQWEsQ0FBQztRQUNuQixLQUFLLGVBQWUsQ0FBQztRQUNyQixLQUFLLFFBQVE7WUFDWCxPQUFPLENBQUMsS0FBSyxDQUFDLE1BQU0sQ0FDaEIsYUFBYSxDQUFDLEdBQUcsRUFBRSxJQUFJLEVBQUUsU0FBUyxFQUFFLE9BQU8sQ0FBYSxFQUN4RCxhQUFhLENBQUMsR0FBRyxFQUFFLElBQUksRUFBRSxTQUFTLEVBQUUsT0FBTyxDQUFhLEVBQ3hELGFBQWEsQ0FBQyxZQUFZLEVBQUUsSUFBSSxFQUFFLFNBQVMsRUFBRSxPQUFPLENBQVksRUFDaEUsYUFBYSxDQUFDLFlBQVksRUFBRSxJQUFJLEVBQUUsU0FBUyxFQUFFLE9BQU8sQ0FDekMsQ0FBQyxDQUFDLENBQUM7UUFFcEIsS0FBSyxRQUFRO1lBQ1gsT0FBTyxDQUFDLEtBQUssQ0FBQyxNQUFNLENBQ2hCLGFBQWEsQ0FBQyxVQUFVLEVBQUUsSUFBSSxFQUFFLFNBQVMsRUFBRSxPQUFPLENBQVcsRUFDN0QsR0FBRyxhQUFhLENBQUMsU0FBUyxFQUFFLElBQUksRUFBRSxTQUFTLEVBQUUsT0FBTyxDQUN4QyxDQUFDLENBQUMsQ0FBQztRQUVyQixLQUFLLFdBQVc7WUFDZCxPQUFPLENBQUMsS0FBSyxDQUFDLFNBQVMsQ0FDbkIsYUFBYSxDQUFDLEdBQUcsRUFBRSxJQUFJLEVBQUUsU0FBUyxFQUFFLE9BQU8sQ0FBVyxFQUN0RCxhQUFhLENBQUMsTUFBTSxFQUFFLElBQUksRUFBRSxTQUFTLEVBQUUsT0FBTyxDQUFhLENBQUMsQ0FBQyxDQUFDO1FBRXBFLEtBQUssY0FBYztZQUNqQixNQUFNLENBQUMsT0FBTyxFQUFFLGNBQWMsQ0FBQyxHQUMxQixhQUFhLENBQUMsVUFBVSxFQUFFLElBQUksRUFBRSxTQUFTLEVBQUUsT0FBTyxDQUFjLENBQUM7WUFFdEUsTUFBTSxTQUFTLEdBQUcsT0FBTyxLQUFLLFNBQVMsQ0FBQztZQUN4QyxNQUFNLE9BQU8sR0FBRyxjQUFjLEtBQUssT0FBTyxDQUFDO1lBRTNDLE1BQU0sT0FBTyxHQUNSLGFBQWEsQ0FBQyxTQUFTLEVBQUUsSUFBSSxFQUFFLFNBQVMsRUFBRSxPQUFPLENBQVksQ0FBQztZQUNuRSxNQUFNLGNBQWMsR0FDaEIsYUFBYSxDQUFDLGdCQUFnQixFQUFFLElBQUksRUFBRSxTQUFTLEVBQUUsT0FBTyxDQUNsRCxDQUFDO1lBRVgsSUFBSSxTQUFTLEVBQUU7Z0JBQ2IsSUFBSSxPQU
FPLElBQUksT0FBTyxLQUFLLENBQUMsRUFBRTtvQkFDNUIsTUFBTSxJQUFJLEtBQUssQ0FDWCxvREFBb0Q7d0JBQ3BELGtDQUFrQyxDQUFDLENBQUM7aUJBQ3pDO2dCQUNELElBQUksQ0FBQyxPQUFPLElBQUksT0FBTyxLQUFLLENBQUMsRUFBRTtvQkFDN0IsTUFBTSxJQUFJLEtBQUssQ0FDWCwrREFBK0QsQ0FBQyxDQUFDO2lCQUN0RTthQUNGO1lBQ0QsTUFBTSxDQUFDLE9BQU8sRUFBRSxRQUFRLENBQUMsR0FDckIsYUFBYSxDQUFDLE1BQU0sRUFBRSxJQUFJLEVBQUUsU0FBUyxFQUFFLE9BQU8sQ0FBYSxDQUFDO1lBQ2hFLE9BQU8sQ0FBQyxLQUFLLENBQUMsS0FBSyxDQUFDLE1BQU0sQ0FBQztvQkFDekIsQ0FBQyxFQUFFLGFBQWEsQ0FBQyxHQUFHLEVBQUUsSUFBSSxFQUFFLFNBQVMsRUFBRSxPQUFPLENBQWE7b0JBQzNELENBQUMsRUFBRSxhQUFhLENBQUMsR0FBRyxFQUFFLElBQUksRUFBRSxTQUFTLEVBQUUsT0FBTyxDQUFhO29CQUMzRCxVQUFVLEVBQUUsYUFBYSxDQUFDLFlBQVksRUFBRSxJQUFJLEVBQUUsU0FBUyxFQUFFLE9BQU8sQ0FDckQ7b0JBQ1gsVUFBVSxFQUFFLGFBQWEsQ0FBQyxZQUFZLEVBQUUsSUFBSSxFQUFFLFNBQVMsRUFBRSxPQUFPLENBQ3JEO29CQUNYLElBQUksRUFBRSxPQUFPO29CQUNiLFVBQVUsRUFBRSxjQUF3QztvQkFDcEQsc0JBQXNCLEVBQUUsUUFBUTtvQkFDaEMsY0FBYztpQkFDZixDQUFDLENBQUMsQ0FBQztRQUVOO1lBQ0UsTUFBTSxTQUFTLENBQUMsYUFBYSxJQUFJLENBQUMsRUFBRSxxQkFBcUIsQ0FBQyxDQUFDO0tBQzlEO0FBQ0gsQ0FBQyxDQUFDO0FBRU4sTUFBTSxDQUFDLE1BQU0sUUFBUSxHQUFHLFVBQVUsQ0FBQyIsInNvdXJjZXNDb250ZW50IjpbIi8qKlxuICogQGxpY2Vuc2VcbiAqIENvcHlyaWdodCAyMDE4IEdvb2dsZSBMTEMuIEFsbCBSaWdodHMgUmVzZXJ2ZWQuXG4gKiBMaWNlbnNlZCB1bmRlciB0aGUgQXBhY2hlIExpY2Vuc2UsIFZlcnNpb24gMi4wICh0aGUgXCJMaWNlbnNlXCIpO1xuICogeW91IG1heSBub3QgdXNlIHRoaXMgZmlsZSBleGNlcHQgaW4gY29tcGxpYW5jZSB3aXRoIHRoZSBMaWNlbnNlLlxuICogWW91IG1heSBvYnRhaW4gYSBjb3B5IG9mIHRoZSBMaWNlbnNlIGF0XG4gKlxuICogaHR0cDovL3d3dy5hcGFjaGUub3JnL2xpY2Vuc2VzL0xJQ0VOU0UtMi4wXG4gKlxuICogVW5sZXNzIHJlcXVpcmVkIGJ5IGFwcGxpY2FibGUgbGF3IG9yIGFncmVlZCB0byBpbiB3cml0aW5nLCBzb2Z0d2FyZVxuICogZGlzdHJpYnV0ZWQgdW5kZXIgdGhlIExpY2Vuc2UgaXMgZGlzdHJpYnV0ZWQgb24gYW4gXCJBUyBJU1wiIEJBU0lTLFxuICogV0lUSE9VVCBXQVJSQU5USUVTIE9SIENPTkRJVElPTlMgT0YgQU5ZIEtJTkQsIGVpdGhlciBleHByZXNzIG9yIGltcGxpZWQuXG4gKiBTZWUgdGhlIExpY2Vuc2UgZm9yIHRoZSBzcGVjaWZpYyBsYW5ndWFnZSBnb3Zlcm5pbmcgcGVybWlzc2lvbnMgYW5kXG4gKiBsaW1pdGF0aW9ucyB1bmRlciB0aGUgTGljZW5zZS5cbiAqID09PT09PT09PT09PT
09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09XG4gKi9cblxuaW1wb3J0IHtUZW5zb3IsIFRlbnNvcjJEfSBmcm9tICdAdGVuc29yZmxvdy90ZmpzLWNvcmUnO1xuLy8gdHNsaW50OmRpc2FibGUtbmV4dC1saW5lOiBuby1pbXBvcnRzLWZyb20tZGlzdFxuaW1wb3J0ICogYXMgdGZPcHMgZnJvbSAnQHRlbnNvcmZsb3cvdGZqcy1jb3JlL2Rpc3Qvb3BzL29wc19mb3JfY29udmVydGVyJztcblxuaW1wb3J0IHtOYW1lZFRlbnNvcnNNYXB9IGZyb20gJy4uLy4uL2RhdGEvdHlwZXMnO1xuaW1wb3J0IHtFeGVjdXRpb25Db250ZXh0fSBmcm9tICcuLi8uLi9leGVjdXRvci9leGVjdXRpb25fY29udGV4dCc7XG5pbXBvcnQge0ludGVybmFsT3BFeGVjdXRvciwgTm9kZX0gZnJvbSAnLi4vdHlwZXMnO1xuXG5pbXBvcnQge2dldFBhcmFtVmFsdWV9IGZyb20gJy4vdXRpbHMnO1xuXG5leHBvcnQgY29uc3QgZXhlY3V0ZU9wOiBJbnRlcm5hbE9wRXhlY3V0b3IgPVxuICAgIChub2RlOiBOb2RlLCB0ZW5zb3JNYXA6IE5hbWVkVGVuc29yc01hcCxcbiAgICAgY29udGV4dDogRXhlY3V0aW9uQ29udGV4dCk6IFRlbnNvcltdID0+IHtcbiAgICAgIHN3aXRjaCAobm9kZS5vcCkge1xuICAgICAgICBjYXNlICdCYXRjaE1hdE11bCc6XG4gICAgICAgIGNhc2UgJ0JhdGNoTWF0TXVsVjInOlxuICAgICAgICBjYXNlICdNYXRNdWwnOlxuICAgICAgICAgIHJldHVybiBbdGZPcHMubWF0TXVsKFxuICAgICAgICAgICAgICBnZXRQYXJhbVZhbHVlKCdhJywgbm9kZSwgdGVuc29yTWFwLCBjb250ZXh0KSBhcyBUZW5zb3IyRCxcbiAgICAgICAgICAgICAgZ2V0UGFyYW1WYWx1ZSgnYicsIG5vZGUsIHRlbnNvck1hcCwgY29udGV4dCkgYXMgVGVuc29yMkQsXG4gICAgICAgICAgICAgIGdldFBhcmFtVmFsdWUoJ3RyYW5zcG9zZUEnLCBub2RlLCB0ZW5zb3JNYXAsIGNvbnRleHQpIGFzIGJvb2xlYW4sXG4gICAgICAgICAgICAgIGdldFBhcmFtVmFsdWUoJ3RyYW5zcG9zZUInLCBub2RlLCB0ZW5zb3JNYXAsIGNvbnRleHQpIGFzXG4gICAgICAgICAgICAgICAgICBib29sZWFuKV07XG5cbiAgICAgICAgY2FzZSAnRWluc3VtJzpcbiAgICAgICAgICByZXR1cm4gW3RmT3BzLmVpbnN1bShcbiAgICAgICAgICAgICAgZ2V0UGFyYW1WYWx1ZSgnZXF1YXRpb24nLCBub2RlLCB0ZW5zb3JNYXAsIGNvbnRleHQpIGFzIHN0cmluZyxcbiAgICAgICAgICAgICAgLi4uZ2V0UGFyYW1WYWx1ZSgndGVuc29ycycsIG5vZGUsIHRlbnNvck1hcCwgY29udGV4dCkgYXNcbiAgICAgICAgICAgICAgICAgIFRlbnNvcltdKV07XG5cbiAgICAgICAgY2FzZSAnVHJhbnNwb3NlJzpcbiAgICAgICAgICByZXR1cm4gW3RmT3BzLnRyYW5zcG9zZShcbiAgICAgICAgICAgICAgZ2V0UGFyYW1WYWx1ZSgneCcsIG5vZGUsIHRlbnNvck1hcCwgY29udGV4dCkgYXMgVGVuc29yLFxuICAgICAgICAgICAgICBnZXRQYXJhbVZhbHVlKCdwZXJtJywgbm
9kZSwgdGVuc29yTWFwLCBjb250ZXh0KSBhcyBudW1iZXJbXSldO1xuXG4gICAgICAgIGNhc2UgJ19GdXNlZE1hdE11bCc6XG4gICAgICAgICAgY29uc3QgW2V4dHJhT3AsIGFjdGl2YXRpb25GdW5jXSA9XG4gICAgICAgICAgICAgIChnZXRQYXJhbVZhbHVlKCdmdXNlZE9wcycsIG5vZGUsIHRlbnNvck1hcCwgY29udGV4dCkgYXMgc3RyaW5nW10pO1xuXG4gICAgICAgICAgY29uc3QgaXNCaWFzQWRkID0gZXh0cmFPcCA9PT0gJ2JpYXNhZGQnO1xuICAgICAgICAgIGNvbnN0IGlzUHJlbHUgPSBhY3RpdmF0aW9uRnVuYyA9PT0gJ3ByZWx1JztcblxuICAgICAgICAgIGNvbnN0IG51bUFyZ3MgPVxuICAgICAgICAgICAgICAoZ2V0UGFyYW1WYWx1ZSgnbnVtQXJncycsIG5vZGUsIHRlbnNvck1hcCwgY29udGV4dCkgYXMgbnVtYmVyKTtcbiAgICAgICAgICBjb25zdCBsZWFreXJlbHVBbHBoYSA9XG4gICAgICAgICAgICAgIGdldFBhcmFtVmFsdWUoJ2xlYWt5cmVsdUFscGhhJywgbm9kZSwgdGVuc29yTWFwLCBjb250ZXh0KSBhc1xuICAgICAgICAgICAgICBudW1iZXI7XG5cbiAgICAgICAgICBpZiAoaXNCaWFzQWRkKSB7XG4gICAgICAgICAgICBpZiAoaXNQcmVsdSAmJiBudW1BcmdzICE9PSAyKSB7XG4gICAgICAgICAgICAgIHRocm93IG5ldyBFcnJvcihcbiAgICAgICAgICAgICAgICAgICdGdXNlZCBNYXRNdWwgd2l0aCBCaWFzQWRkIGFuZCBQcmVsdSBtdXN0IGhhdmUgdHdvICcgK1xuICAgICAgICAgICAgICAgICAgJ2V4dHJhIGFyZ3VtZW50czogYmlhcyBhbmQgYWxwaGEuJyk7XG4gICAgICAgICAgICB9XG4gICAgICAgICAgICBpZiAoIWlzUHJlbHUgJiYgbnVtQXJncyAhPT0gMSkge1xuICAgICAgICAgICAgICB0aHJvdyBuZXcgRXJyb3IoXG4gICAgICAgICAgICAgICAgICAnRnVzZWQgTWF0TXVsIHdpdGggQmlhc0FkZCBtdXN0IGhhdmUgb25lIGV4dHJhIGFyZ3VtZW50OiBiaWFzLicpO1xuICAgICAgICAgICAgfVxuICAgICAgICAgIH1cbiAgICAgICAgICBjb25zdCBbYmlhc0FyZywgcHJlbHVBcmddID1cbiAgICAgICAgICAgICAgZ2V0UGFyYW1WYWx1ZSgnYXJncycsIG5vZGUsIHRlbnNvck1hcCwgY29udGV4dCkgYXMgVGVuc29yW107XG4gICAgICAgICAgcmV0dXJuIFt0Zk9wcy5mdXNlZC5tYXRNdWwoe1xuICAgICAgICAgICAgYTogZ2V0UGFyYW1WYWx1ZSgnYScsIG5vZGUsIHRlbnNvck1hcCwgY29udGV4dCkgYXMgVGVuc29yMkQsXG4gICAgICAgICAgICBiOiBnZXRQYXJhbVZhbHVlKCdiJywgbm9kZSwgdGVuc29yTWFwLCBjb250ZXh0KSBhcyBUZW5zb3IyRCxcbiAgICAgICAgICAgIHRyYW5zcG9zZUE6IGdldFBhcmFtVmFsdWUoJ3RyYW5zcG9zZUEnLCBub2RlLCB0ZW5zb3JNYXAsIGNvbnRleHQpIGFzXG4gICAgICAgICAgICAgICAgYm9vbGVhbixcbiAgICAgICAgICAgIHRyYW5zcG9zZUI6IGdldFBhcmFtVmFsdWUoJ3RyYW5zcG9zZUInLCBub2RlLCB0ZW5zb3JNYXAsIGNvbnRleHQpIGFzXG4gICAgICAgICAgICAgIC
AgYm9vbGVhbixcbiAgICAgICAgICAgIGJpYXM6IGJpYXNBcmcsXG4gICAgICAgICAgICBhY3RpdmF0aW9uOiBhY3RpdmF0aW9uRnVuYyBhcyB0Zk9wcy5mdXNlZC5BY3RpdmF0aW9uLFxuICAgICAgICAgICAgcHJlbHVBY3RpdmF0aW9uV2VpZ2h0czogcHJlbHVBcmcsXG4gICAgICAgICAgICBsZWFreXJlbHVBbHBoYVxuICAgICAgICAgIH0pXTtcblxuICAgICAgICBkZWZhdWx0OlxuICAgICAgICAgIHRocm93IFR5cGVFcnJvcihgTm9kZSB0eXBlICR7bm9kZS5vcH0gaXMgbm90IGltcGxlbWVudGVkYCk7XG4gICAgICB9XG4gICAgfTtcblxuZXhwb3J0IGNvbnN0IENBVEVHT1JZID0gJ21hdHJpY2VzJztcbiJdfQ==