UNPKG

@tensorflow/tfjs

Version:

An open-source machine learning framework.

110 lines 4.88 kB
"use strict"; /** * @license * Copyright 2020 Google Inc. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ Object.defineProperty(exports, "__esModule", { value: true }); exports.getCustomConverterOpsModule = exports.getCustomModuleString = void 0; var util_1 = require("./util"); function getCustomModuleString(config, moduleProvider) { var kernels = config.kernels, backends = config.backends, forwardModeOnly = config.forwardModeOnly, models = config.models; var tfjs = [(0, util_1.getPreamble)()]; // A custom tfjs module addLine(tfjs, moduleProvider.importCoreStr(forwardModeOnly)); if (models.length > 0) { // A model.json has been passed. addLine(tfjs, moduleProvider.importConverterStr()); } for (var _i = 0, backends_1 = backends; _i < backends_1.length; _i++) { var backend = backends_1[_i]; addLine(tfjs, "\n//backend = ".concat(backend)); addLine(tfjs, moduleProvider.importBackendStr(backend)); for (var _a = 0, kernels_1 = kernels; _a < kernels_1.length; _a++) { var kernelName = kernels_1[_a]; var kernelImport = moduleProvider.importKernelStr(kernelName, backend); if (!moduleProvider.validateImportPath(kernelImport.importPath)) { console.warn('WARNING:', "Import path '".concat(kernelImport.importPath, "' cannot be resolved. Skipping...")); continue; } addLine(tfjs, kernelImport.importStatement); addLine(tfjs, registerKernelStr(kernelImport.kernelConfigId)); } } if (!forwardModeOnly) { addLine(tfjs, "\n//Gradients"); for (var _b = 0, kernels_2 = kernels; _b < kernels_2.length; _b++) { var kernelName = kernels_2[_b]; var gradImport = moduleProvider.importGradientConfigStr(kernelName); if (!moduleProvider.validateImportPath(gradImport.importPath)) { console.warn('WARNING:', "Import path '".concat(gradImport.importPath, "' cannot be resolved. Skipping...")); continue; } addLine(tfjs, gradImport.importStatement); addLine(tfjs, registerGradientConfigStr(gradImport.gradConfigId)); } } // A custom tfjs core module for imports within tfjs packages var core = [(0, util_1.getPreamble)()]; addLine(core, moduleProvider.importCoreStr(forwardModeOnly)); return { core: core.join('\n'), tfjs: tfjs.join('\n'), }; } exports.getCustomModuleString = getCustomModuleString; function getCustomConverterOpsModule(ops, moduleProvider) { var result = ['// This file is autogenerated\n']; // Separate namespaced apis from non namespaced ones as they require a // different export pattern that treats each namespace as a whole. var flatOps = []; var namespacedOps = {}; for (var _i = 0, ops_1 = ops; _i < ops_1.length; _i++) { var opSymbol = ops_1[_i]; if (opSymbol.match(/\./)) { var parts = opSymbol.split(/\./); var namespace = parts[0]; var opName = parts[1]; if (namespacedOps[namespace] == null) { namespacedOps[namespace] = []; } namespacedOps[namespace].push(opName); } else { flatOps.push(opSymbol); } } // Group the namespaced symbols by namespace for (var _a = 0, _b = Object.keys(namespacedOps); _a < _b.length; _a++) { var namespace = _b[_a]; var opSymbols = namespacedOps[namespace]; result.push(moduleProvider.importNamespacedOpsForConverterStr(namespace, opSymbols)); } for (var _c = 0, flatOps_1 = flatOps; _c < flatOps_1.length; _c++) { var opSymbol = flatOps_1[_c]; result.push(moduleProvider.importOpForConverterStr(opSymbol)); } return result.join('\n'); } exports.getCustomConverterOpsModule = getCustomConverterOpsModule; function addLine(target, line) { target.push(line); } function registerKernelStr(kernelConfigId) { return "registerKernel(".concat(kernelConfigId, ");"); } function registerGradientConfigStr(gradConfigId) { return "registerGradient(".concat(gradConfigId, ");"); } //# sourceMappingURL=custom_module.js.map