@tensorflow/tfjs
Version:
An open-source machine learning framework.
144 lines • 6.75 kB
JavaScript
;
/**
* @license
* Copyright 2020 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
Object.defineProperty(exports, "__esModule", { value: true });
exports.esmImportProvider = exports.getModuleProvider = void 0;
var fs = require("fs");
var path = require("path");
var custom_module_1 = require("./custom_module");
var model_parser_1 = require("./model_parser");
var util_1 = require("./util");
function getModuleProvider(opts) {
return new ESMModuleProvider();
}
exports.getModuleProvider = getModuleProvider;
var ESMModuleProvider = /** @class */ (function () {
function ESMModuleProvider() {
}
/**
* Writes out custom tfjs module(s) to disk.
*/
ESMModuleProvider.prototype.produceCustomTFJSModule = function (config) {
var normalizedOutputPath = config.normalizedOutputPath;
var moduleStrs = (0, custom_module_1.getCustomModuleString)(config, exports.esmImportProvider);
fs.mkdirSync(normalizedOutputPath, { recursive: true });
console.log("Will write custom tfjs modules to ".concat(normalizedOutputPath));
var customTfjsFileName = 'custom_tfjs.js';
var customTfjsCoreFileName = 'custom_tfjs_core.js';
// Write a custom module for @tensorflow/tfjs and @tensorflow/tfjs-core
fs.writeFileSync(path.join(normalizedOutputPath, customTfjsCoreFileName), moduleStrs.core);
fs.writeFileSync(path.join(normalizedOutputPath, customTfjsFileName), moduleStrs.tfjs);
// Write a custom module tfjs-core ops used by converter executors
var kernelToOps;
var mappingPath;
try {
mappingPath =
require.resolve('@tensorflow/tfjs-converter/metadata/kernel2op.json');
kernelToOps = JSON.parse(fs.readFileSync(mappingPath, 'utf-8'));
}
catch (e) {
(0, util_1.bail)("Error loading kernel to ops mapping file ".concat(mappingPath));
}
var converterOps = (0, model_parser_1.getOpsForConfig)(config, kernelToOps);
if (converterOps.length > 0) {
var converterOpsModule = (0, custom_module_1.getCustomConverterOpsModule)(converterOps, exports.esmImportProvider);
var customConverterOpsFileName = 'custom_ops_for_converter.js';
fs.writeFileSync(path.join(normalizedOutputPath, customConverterOpsFileName), converterOpsModule);
}
};
return ESMModuleProvider;
}());
/**
* An import provider to generate custom esm modules.
*/
// Exported for tests.
exports.esmImportProvider = {
importCoreStr: function (forwardModeOnly) {
var importLines = [
"import {registerKernel} from '@tensorflow/tfjs-core/dist/base';",
"import '@tensorflow/tfjs-core/dist/base_side_effects';",
"export * from '@tensorflow/tfjs-core/dist/base';"
];
if (!forwardModeOnly) {
importLines.push("import {registerGradient} from '@tensorflow/tfjs-core/dist/base';");
}
return importLines.join('\n');
},
importConverterStr: function () {
return "export * from '@tensorflow/tfjs-converter';";
},
importBackendStr: function (backend) {
var backendPkg = getBackendPath(backend);
return "export * from '".concat(backendPkg, "/dist/base';");
},
importKernelStr: function (kernelName, backend) {
var backendPkg = getBackendPath(backend);
var kernelConfigId = "".concat(kernelName, "_").concat(backend);
var importPath = "".concat(backendPkg, "/dist/kernels/").concat(kernelName);
var importStatement = "import {".concat((0, util_1.kernelNameToVariableName)(kernelName), "Config as ").concat(kernelConfigId, "} from '").concat(importPath, "';");
return { importPath: importPath, importStatement: importStatement, kernelConfigId: kernelConfigId };
},
importGradientConfigStr: function (kernelName) {
var gradConfigId = "".concat((0, util_1.kernelNameToVariableName)(kernelName), "GradConfig");
var importPath = "@tensorflow/tfjs-core/dist/gradients/".concat(kernelName, "_grad");
var importStatement = "import {".concat(gradConfigId, "} from '").concat(importPath, "';");
return { importPath: importPath, importStatement: importStatement, gradConfigId: gradConfigId };
},
importOpForConverterStr: function (opSymbol) {
var opFileName = (0, util_1.opNameToFileName)(opSymbol);
return "export {".concat(opSymbol, "} from '@tensorflow/tfjs-core/dist/ops/").concat(opFileName, "';");
},
importNamespacedOpsForConverterStr: function (namespace, opSymbols) {
var result = [];
for (var _i = 0, opSymbols_1 = opSymbols; _i < opSymbols_1.length; _i++) {
var opSymbol = opSymbols_1[_i];
var opFileName = (0, util_1.opNameToFileName)(opSymbol);
var opAlias = "".concat(opSymbol, "_").concat(namespace);
result.push("import {".concat(opSymbol, " as ").concat(opAlias, "} from '@tensorflow/tfjs-core/dist/ops/").concat(namespace, "/").concat(opFileName, "';"));
}
result.push("export const ".concat(namespace, " = {"));
for (var _a = 0, opSymbols_2 = opSymbols; _a < opSymbols_2.length; _a++) {
var opSymbol = opSymbols_2[_a];
var opAlias = "".concat(opSymbol, "_").concat(namespace);
result.push("\t".concat(opSymbol, ": ").concat(opAlias, ","));
}
result.push("};");
return result.join('\n');
},
validateImportPath: function (importPath) {
try {
require.resolve(importPath);
return true;
}
catch (e) {
return false;
}
}
};
function getBackendPath(backend) {
switch (backend) {
case 'cpu':
return '@tensorflow/tfjs-backend-cpu';
case 'webgl':
return '@tensorflow/tfjs-backend-webgl';
case 'wasm':
return '@tensorflow/tfjs-backend-wasm';
default:
throw new Error("Unsupported backend ".concat(backend));
}
}
//# sourceMappingURL=esm_module_provider.js.map