@tensorflow/tfjs
Version:
An open-source machine learning framework.
173 lines (146 loc) • 5.78 kB
text/typescript
/**
* @license
* Copyright 2020 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import * as fs from 'fs';
import * as path from 'path';
import {getCustomConverterOpsModule, getCustomModuleString} from './custom_module';
import {getOpsForConfig} from './model_parser';
import {CustomTFJSBundleConfig, ImportProvider, ModuleProvider, SupportedBackend} from './types';
import {bail, kernelNameToVariableName, opNameToFileName} from './util';
/**
 * Creates the module provider used to emit custom tfjs modules.
 *
 * @param opts Options object; it is not read by this implementation.
 * @returns A freshly constructed ESM-based module provider.
 */
export function getModuleProvider(opts: {}): ModuleProvider {
  const provider: ModuleProvider = new ESMModuleProvider();
  return provider;
}
/**
 * Module provider that writes the custom tfjs bundle as ES modules.
 */
class ESMModuleProvider implements ModuleProvider {
  /**
   * Writes out custom tfjs module(s) to disk.
   *
   * Creates `config.normalizedOutputPath` (recursively) and writes
   * `custom_tfjs_core.js` and `custom_tfjs.js` into it. When the config
   * requires converter ops, `custom_ops_for_converter.js` is written as well.
   *
   * @param config Bundle configuration; `normalizedOutputPath` is the
   *     directory that receives the generated files.
   */
  produceCustomTFJSModule(config: CustomTFJSBundleConfig) {
    const {normalizedOutputPath} = config;

    const moduleStrs = getCustomModuleString(config, esmImportProvider);

    fs.mkdirSync(normalizedOutputPath, {recursive: true});
    console.log(`Will write custom tfjs modules to ${normalizedOutputPath}`);

    const customTfjsFileName = 'custom_tfjs.js';
    const customTfjsCoreFileName = 'custom_tfjs_core.js';

    // Write a custom module for @tensorflow/tfjs and @tensorflow/tfjs-core
    fs.writeFileSync(
        path.join(normalizedOutputPath, customTfjsCoreFileName),
        moduleStrs.core);
    fs.writeFileSync(
        path.join(normalizedOutputPath, customTfjsFileName), moduleStrs.tfjs);

    // Write a custom module of the tfjs-core ops used by converter executors
    const kernelToOpsModule =
        '@tensorflow/tfjs-converter/metadata/kernel2op.json';
    let kernelToOps;
    // Start from the module specifier so that a failure inside
    // `require.resolve` still reports something meaningful rather than
    // "undefined".
    let mappingPath = kernelToOpsModule;
    try {
      mappingPath = require.resolve(kernelToOpsModule);
      kernelToOps = JSON.parse(fs.readFileSync(mappingPath, 'utf-8'));
    } catch (e) {
      // Surface the underlying cause (module not found, unreadable file,
      // malformed JSON) instead of discarding it.
      const reason = e instanceof Error ? e.message : String(e);
      // NOTE(review): assumes `bail` does not return; otherwise
      // `kernelToOps` is undefined below - confirm.
      bail(
          `Error loading kernel to ops mapping file ${mappingPath}: ${reason}`);
    }

    const converterOps = getOpsForConfig(config, kernelToOps);
    if (converterOps.length > 0) {
      const converterOpsModule =
          getCustomConverterOpsModule(converterOps, esmImportProvider);

      const customConverterOpsFileName = 'custom_ops_for_converter.js';
      fs.writeFileSync(
          path.join(normalizedOutputPath, customConverterOpsFileName),
          converterOpsModule);
    }
  }
}
/**
* An import provider that generates the import/export statements for custom
* ESM modules.
*/
// Exported for tests.
export const esmImportProvider: ImportProvider = {
  /**
   * Builds the statements that import and re-export the tfjs-core base
   * module. The `registerGradient` import is appended only when
   * `forwardModeOnly` is false.
   */
  importCoreStr(forwardModeOnly: boolean) {
    const gradientLines: string[] = forwardModeOnly ?
        [] :
        [`import {registerGradient} from '@tensorflow/tfjs-core/dist/base';`];

    return [
      `import {registerKernel} from '@tensorflow/tfjs-core/dist/base';`,
      `import '@tensorflow/tfjs-core/dist/base_side_effects';`,
      `export * from '@tensorflow/tfjs-core/dist/base';`,
      ...gradientLines,
    ].join('\n');
  },

  /** Returns the statement that re-exports everything from tfjs-converter. */
  importConverterStr() {
    return `export * from '@tensorflow/tfjs-converter';`;
  },

  /** Returns the statement that re-exports the base module of a backend. */
  importBackendStr(backend: SupportedBackend) {
    return `export * from '${getBackendPath(backend)}/dist/base';`;
  },

  /**
   * Describes how to import one kernel config from a backend package.
   *
   * @returns The module path, the full import statement, and the local alias
   *     (`<kernelName>_<backend>`) the config is imported under.
   */
  importKernelStr(kernelName: string, backend: SupportedBackend) {
    const importPath = `${getBackendPath(backend)}/dist/kernels/${kernelName}`;
    const kernelConfigId = `${kernelName}_${backend}`;
    const exportedName = `${kernelNameToVariableName(kernelName)}Config`;
    const importStatement =
        `import {${exportedName} as ${kernelConfigId}} from '${importPath}';`;

    return {importPath, importStatement, kernelConfigId};
  },

  /**
   * Describes how to import the gradient config of a kernel from tfjs-core.
   *
   * @returns The module path, the full import statement, and the imported
   *     identifier.
   */
  importGradientConfigStr(kernelName: string) {
    const gradConfigId = `${kernelNameToVariableName(kernelName)}GradConfig`;
    const importPath =
        `@tensorflow/tfjs-core/dist/gradients/${kernelName}_grad`;

    return {
      importPath,
      importStatement: `import {${gradConfigId}} from '${importPath}';`,
      gradConfigId,
    };
  },

  /** Returns the statement that re-exports a single op from tfjs-core. */
  importOpForConverterStr(opSymbol) {
    const opModule =
        `@tensorflow/tfjs-core/dist/ops/${opNameToFileName(opSymbol)}`;
    return `export {${opSymbol}} from '${opModule}';`;
  },

  /**
   * Emits aliased imports for every op of a namespace followed by an exported
   * object literal, named after the namespace, that maps each op symbol to
   * its alias.
   */
  importNamespacedOpsForConverterStr(namespace, opSymbols) {
    const importLines: string[] = [];
    const memberLines: string[] = [];

    for (const opSymbol of opSymbols) {
      const alias = `${opSymbol}_${namespace}`;
      const opModule = `@tensorflow/tfjs-core/dist/ops/${namespace}/${
          opNameToFileName(opSymbol)}`;

      importLines.push(`import {${opSymbol} as ${alias}} from '${opModule}';`);
      memberLines.push(`\t${opSymbol}: ${alias},`);
    }

    return [
      ...importLines,
      `export const ${namespace} = {`,
      ...memberLines,
      `};`,
    ].join('\n');
  },

  /** Reports whether `importPath` can be resolved via `require.resolve`. */
  validateImportPath(importPath: string): boolean {
    let resolvable = true;
    try {
      require.resolve(importPath);
    } catch (e) {
      // Resolution failure means the path is not importable.
      resolvable = false;
    }
    return resolvable;
  }
};
/**
 * Maps a backend name to the npm package that implements it.
 *
 * @param backend One of 'cpu', 'webgl' or 'wasm'.
 * @returns The package name of the matching tfjs backend.
 * @throws Error when the backend is not one of the supported names.
 */
function getBackendPath(backend: SupportedBackend) {
  if (backend === 'cpu') {
    return '@tensorflow/tfjs-backend-cpu';
  }
  if (backend === 'webgl') {
    return '@tensorflow/tfjs-backend-webgl';
  }
  if (backend === 'wasm') {
    return '@tensorflow/tfjs-backend-wasm';
  }
  throw new Error(`Unsupported backend ${backend}`);
}