UNPKG

@tensorflow/tfjs

Version:

An open-source machine learning framework.

173 lines (146 loc) 5.78 kB
/**
 * @license
 * Copyright 2020 Google Inc. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */

import * as fs from 'fs';
import * as path from 'path';

import {getCustomConverterOpsModule, getCustomModuleString} from './custom_module';
import {getOpsForConfig} from './model_parser';
import {CustomTFJSBundleConfig, ImportProvider, ModuleProvider, SupportedBackend} from './types';
import {bail, kernelNameToVariableName, opNameToFileName} from './util';

/**
 * Creates the module provider used to emit custom tfjs modules.
 *
 * @param opts Currently unused; an `ESMModuleProvider` is always returned.
 * @returns A new `ESMModuleProvider` instance.
 */
export function getModuleProvider(opts: {}): ModuleProvider {
  return new ESMModuleProvider();
}

class ESMModuleProvider implements ModuleProvider {
  /**
   * Writes out custom tfjs module(s) to disk.
   *
   * Creates `config.normalizedOutputPath` (recursively) and writes
   * `custom_tfjs_core.js` and `custom_tfjs.js` into it. When the config
   * requires any converter ops, `custom_ops_for_converter.js` is written as
   * well.
   *
   * @param config Bundle configuration; only `normalizedOutputPath` is read
   *     directly here, the rest is forwarded to the module/ops generators.
   */
  produceCustomTFJSModule(config: CustomTFJSBundleConfig) {
    const {normalizedOutputPath} = config;

    const moduleStrs = getCustomModuleString(config, esmImportProvider);

    fs.mkdirSync(normalizedOutputPath, {recursive: true});
    console.log(`Will write custom tfjs modules to ${normalizedOutputPath}`);

    const customTfjsFileName = 'custom_tfjs.js';
    const customTfjsCoreFileName = 'custom_tfjs_core.js';

    // Write a custom module for @tensorflow/tfjs and @tensorflow/tfjs-core
    fs.writeFileSync(
        path.join(normalizedOutputPath, customTfjsCoreFileName),
        moduleStrs.core);
    fs.writeFileSync(
        path.join(normalizedOutputPath, customTfjsFileName), moduleStrs.tfjs);

    // Write a custom module tfjs-core ops used by converter executors
    const mappingModule = '@tensorflow/tfjs-converter/metadata/kernel2op.json';
    let kernelToOps;
    let mappingPath: string | undefined;
    try {
      mappingPath = require.resolve(mappingModule);
      kernelToOps = JSON.parse(fs.readFileSync(mappingPath, 'utf-8'));
    } catch (e) {
      // `mappingPath` is still undefined when resolution itself failed, so
      // fall back to the module specifier instead of printing "undefined",
      // and surface the underlying cause rather than discarding it.
      const location = mappingPath === undefined ? mappingModule : mappingPath;
      const reason = e instanceof Error ? e.message : String(e);
      // NOTE(review): `bail` presumably aborts the process — confirm in util.
      bail(`Error loading kernel to ops mapping file ${location}: ${reason}`);
    }

    const converterOps = getOpsForConfig(config, kernelToOps);
    if (converterOps.length > 0) {
      const converterOpsModule =
          getCustomConverterOpsModule(converterOps, esmImportProvider);

      const customConverterOpsFileName = 'custom_ops_for_converter.js';
      fs.writeFileSync(
          path.join(normalizedOutputPath, customConverterOpsFileName),
          converterOpsModule);
    }
  }
}

/**
 * An import provider to generate custom esm modules.
 */
// Exported for tests.
export const esmImportProvider: ImportProvider = {
  /**
   * Returns the newline-joined import/re-export statements for the custom
   * tfjs-core module. The `registerGradient` import is appended only when
   * `forwardModeOnly` is false.
   */
  importCoreStr(forwardModeOnly: boolean) {
    const baseStatements = [
      `import {registerKernel} from '@tensorflow/tfjs-core/dist/base';`,
      `import '@tensorflow/tfjs-core/dist/base_side_effects';`,
      `export * from '@tensorflow/tfjs-core/dist/base';`,
    ];
    const gradientStatements = forwardModeOnly ?
        [] :
        [`import {registerGradient} from '@tensorflow/tfjs-core/dist/base';`];

    return baseStatements.concat(gradientStatements).join('\n');
  },

  /** Returns the statement re-exporting everything from tfjs-converter. */
  importConverterStr() {
    return `export * from '@tensorflow/tfjs-converter';`;
  },

  /** Returns the statement re-exporting the given backend's `dist/base`. */
  importBackendStr(backend: SupportedBackend) {
    return `export * from '${getBackendPath(backend)}/dist/base';`;
  },

  /**
   * Describes how to import one kernel config from a backend package: the
   * module path, the import statement, and the local alias
   * (`<kernelName>_<backend>`) the config is bound to.
   */
  importKernelStr(kernelName: string, backend: SupportedBackend) {
    const kernelConfigId = `${kernelName}_${backend}`;
    const importPath = `${getBackendPath(backend)}/dist/kernels/${kernelName}`;
    const exportedName = `${kernelNameToVariableName(kernelName)}Config`;
    const importStatement =
        `import {${exportedName} as ${kernelConfigId}} from '${importPath}';`;

    return {importPath, importStatement, kernelConfigId};
  },

  /**
   * Describes how to import the gradient config for a kernel from tfjs-core:
   * the module path, the import statement, and the imported identifier.
   */
  importGradientConfigStr(kernelName: string) {
    const importPath =
        `@tensorflow/tfjs-core/dist/gradients/${kernelName}_grad`;
    const gradConfigId = `${kernelNameToVariableName(kernelName)}GradConfig`;
    const importStatement = `import {${gradConfigId}} from '${importPath}';`;

    return {importPath, importStatement, gradConfigId};
  },

  /** Returns the statement re-exporting a single op from its tfjs-core file. */
  importOpForConverterStr(opSymbol) {
    const sourceFile = opNameToFileName(opSymbol);
    return `export {${opSymbol}} from '@tensorflow/tfjs-core/dist/ops/${
        sourceFile}';`;
  },

  /**
   * Returns one aliased import per op (alias `<op>_<namespace>`), followed by
   * an exported object literal named after the namespace that maps each op
   * name to its alias. Lines are joined with newlines.
   */
  importNamespacedOpsForConverterStr(namespace, opSymbols) {
    const importStatements: string[] = [];
    const namespaceMembers: string[] = [];

    for (const opSymbol of opSymbols) {
      const sourceFile = opNameToFileName(opSymbol);
      const alias = `${opSymbol}_${namespace}`;
      importStatements.push(
          `import {${opSymbol} as ${alias}} from ` +
          `'@tensorflow/tfjs-core/dist/ops/${namespace}/${sourceFile}';`);
      namespaceMembers.push(`\t${opSymbol}: ${alias},`);
    }

    return [
      ...importStatements,
      `export const ${namespace} = {`,
      ...namespaceMembers,
      `};`,
    ].join('\n');
  },

  /** Reports whether `require.resolve` can resolve the given import path. */
  validateImportPath(importPath: string): boolean {
    let resolvable = true;
    try {
      require.resolve(importPath);
    } catch (e) {
      resolvable = false;
    }
    return resolvable;
  }
};

/**
 * Maps a backend name to the npm package that implements it.
 *
 * @throws Error when the backend is not one of `cpu`, `webgl` or `wasm`.
 */
function getBackendPath(backend: SupportedBackend) {
  const packageByBackend = new Map<string, string>([
    ['cpu', '@tensorflow/tfjs-backend-cpu'],
    ['webgl', '@tensorflow/tfjs-backend-webgl'],
    ['wasm', '@tensorflow/tfjs-backend-wasm'],
  ]);

  const backendPkg = packageByBackend.get(backend);
  if (backendPkg === undefined) {
    throw new Error(`Unsupported backend ${backend}`);
  }
  return backendPkg;
}