// @tensorflow/tfjs
// Version:
// An open-source machine learning framework.
// 110 lines • 4.88 kB
// JavaScript
;
/**
* @license
* Copyright 2020 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Transpiler interop flag: marks this CommonJS module as compiled from an ES
// module. Kept in this exact form so static export detection still works.
Object.defineProperty(exports, "__esModule", { value: true });
// Pre-declare both public exports; each is assigned its real value right
// after the corresponding function definition in this file.
exports.getCustomConverterOpsModule = exports.getCustomModuleString = void 0;
var util_1 = require("./util");
/**
 * Builds the source text for a custom tfjs bundle.
 *
 * @param config Bundle configuration. Reads `kernels` (kernel names),
 *     `backends` (backend names), `forwardModeOnly` (when truthy, no gradient
 *     registrations are emitted) and `models` (when non-empty, a converter
 *     import is emitted).
 * @param moduleProvider Renders import statements for core, converter,
 *     backends, kernels and gradients, and validates import paths.
 * @returns `{core, tfjs}`:
 *     - `tfjs`: the full module, lines joined with '\n'. It holds the
 *       preamble, the core import, an optional converter import, a block of
 *       kernel imports/registrations per backend, and optional gradient
 *       imports/registrations.
 *     - `core`: only the preamble plus the core import.
 *     Imports whose path the provider cannot resolve are skipped with a
 *     console warning.
 */
function getCustomModuleString(config, moduleProvider) {
    var kernelNames = config.kernels;
    var backendNames = config.backends;
    var forwardOnly = config.forwardModeOnly;
    var modelPaths = config.models;

    // Appends `imported.importStatement` plus its registration call, unless
    // the provider reports that `imported.importPath` cannot be resolved, in
    // which case a warning is logged and nothing is appended.
    function emitResolvable(lines, imported, configIdKey, makeRegistration) {
        if (moduleProvider.validateImportPath(imported.importPath)) {
            lines.push(imported.importStatement);
            lines.push(makeRegistration(imported[configIdKey]));
        } else {
            console.warn('WARNING:', "Import path '" + imported.importPath + "' cannot be resolved. Skipping...");
        }
    }

    // A custom tfjs module: preamble first, then the core import.
    var bundle = [(0, util_1.getPreamble)(), moduleProvider.importCoreStr(forwardOnly)];
    // The converter import is only emitted when model.json files were given.
    if (modelPaths.length > 0) {
        bundle.push(moduleProvider.importConverterStr());
    }
    // Every kernel is imported and registered once per backend.
    for (var b = 0; b < backendNames.length; b++) {
        var backendName = backendNames[b];
        bundle.push('\n//backend = ' + backendName);
        bundle.push(moduleProvider.importBackendStr(backendName));
        for (var k = 0; k < kernelNames.length; k++) {
            emitResolvable(bundle, moduleProvider.importKernelStr(kernelNames[k], backendName), 'kernelConfigId', registerKernelStr);
        }
    }
    // Gradient configs are registered only for bundles that are not forward-mode only.
    if (!forwardOnly) {
        bundle.push('\n//Gradients');
        for (var g = 0; g < kernelNames.length; g++) {
            emitResolvable(bundle, moduleProvider.importGradientConfigStr(kernelNames[g]), 'gradConfigId', registerGradientConfigStr);
        }
    }
    // A custom tfjs core module for imports within tfjs packages.
    var coreModule = [(0, util_1.getPreamble)(), moduleProvider.importCoreStr(forwardOnly)];
    return {
        core: coreModule.join('\n'),
        tfjs: bundle.join('\n'),
    };
}
// Publish getCustomModuleString on the CommonJS exports object.
exports.getCustomModuleString = getCustomModuleString;
/**
 * Builds the source of an autogenerated module exposing the ops the
 * converter needs.
 *
 * Op symbols containing a dot (e.g. `image.resizeBilinear`) are grouped by
 * namespace (the text before the first dot). Each namespace is emitted with
 * one `importNamespacedOpsForConverterStr` call. All remaining (flat) symbols
 * are emitted afterwards with `importOpForConverterStr`, in input order.
 *
 * @param ops Array of op symbol strings.
 * @param moduleProvider Renders the per-namespace and per-op statements.
 * @returns The module source, with lines joined by '\n'.
 */
function getCustomConverterOpsModule(ops, moduleProvider) {
    var result = ['// This file is autogenerated\n'];
    // Separate namespaced apis from non namespaced ones as they require a
    // different export pattern that treats each namespace as a whole.
    var flatOps = [];
    // Prototype-less dictionary: namespaces that coincide with
    // Object.prototype members (`constructor`, `toString`, `__proto__`, ...)
    // must become ordinary keys rather than resolve to inherited values.
    // Otherwise the `== null` check below would be skipped and `.push` would throw.
    var namespacedOps = Object.create(null);
    for (var i = 0; i < ops.length; i++) {
        var opSymbol = ops[i];
        if (opSymbol.indexOf('.') !== -1) {
            var parts = opSymbol.split('.');
            var namespace = parts[0];
            // Only the segment directly after the first dot is used as the op name.
            var opName = parts[1];
            if (namespacedOps[namespace] == null) {
                namespacedOps[namespace] = [];
            }
            namespacedOps[namespace].push(opName);
        }
        else {
            flatOps.push(opSymbol);
        }
    }
    // Emit one statement per namespace, covering all of its ops at once.
    var namespaces = Object.keys(namespacedOps);
    for (var n = 0; n < namespaces.length; n++) {
        var ns = namespaces[n];
        result.push(moduleProvider.importNamespacedOpsForConverterStr(ns, namespacedOps[ns]));
    }
    for (var f = 0; f < flatOps.length; f++) {
        result.push(moduleProvider.importOpForConverterStr(flatOps[f]));
    }
    return result.join('\n');
}
// Publish getCustomConverterOpsModule on the CommonJS exports object.
exports.getCustomConverterOpsModule = getCustomConverterOpsModule;
/**
 * Appends `line` to the end of the `target` array (mutated in place).
 *
 * @param target Array of output lines.
 * @param line Line to append.
 */
function addLine(target, line) {
    target[target.length] = line;
}
/**
 * Renders the statement that registers an imported kernel config.
 *
 * @param kernelConfigId Identifier the kernel config was imported as.
 * @returns A string of the form `registerKernel(<kernelConfigId>);`.
 */
function registerKernelStr(kernelConfigId) {
    return 'registerKernel(' + kernelConfigId + ');';
}
/**
 * Renders the statement that registers an imported gradient config.
 *
 * @param gradConfigId Identifier the gradient config was imported as.
 * @returns A string of the form `registerGradient(<gradConfigId>);`.
 */
function registerGradientConfigStr(gradConfigId) {
    return 'registerGradient(' + gradConfigId + ');';
}
//# sourceMappingURL=custom_module.js.map