@tensorflow/tfjs
Version:
An open-source machine learning framework.
277 lines (229 loc) • 9.83 kB
text/typescript
/**
* @license
* Copyright 2020 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import {getCustomConverterOpsModule, getCustomModuleString} from './custom_module';
import {CustomTFJSBundleConfig, ImportProvider} from './types';
/**
 * Fake ImportProvider shared by every spec in this file.
 *
 * Instead of real import statements, each method returns a short marker
 * string, so specs can assert on the generated module text. A kernel named
 * 'Invalid' gets an import path that contains 'Invalid'. validateImportPath
 * rejects any such path, which lets specs exercise the "skip invalid kernel"
 * branch.
 */
const mockImportProvider: ImportProvider = {
  importCoreStr() {
    return 'import CORE';
  },
  importConverterStr() {
    return 'import CONVERTER';
  },
  importBackendStr(name: string) {
    return `import BACKEND ${name}`;
  },
  importKernelStr(kernelName: string, backend: string) {
    // Only the literal kernel name 'Invalid' is routed to a rejected path.
    const base = kernelName === 'Invalid' ? 'BACKEND_Invalid' : 'BACKEND';
    const importPath = `${base} ${backend}`;
    return {
      importPath,
      importStatement: `import KERNEL ${kernelName} from ${importPath}`,
      kernelConfigId: `${kernelName}_${backend}`
    };
  },
  importGradientConfigStr(kernel: string) {
    const importPath = kernel === 'Invalid' ? 'BACKEND_Invalid' : 'BACKEND';
    const importStatement = `import GRADIENT ${kernel} from ${importPath}`;
    const gradConfigId = `${kernel}_GRAD_CONFIG`;
    return {importPath, importStatement, gradConfigId};
  },
  importOpForConverterStr(opSymbol: string) {
    return `export * from ${opSymbol}`;
  },
  importNamespacedOpsForConverterStr(namespace: string, opSymbols: string[]) {
    const symbolList = opSymbols.join(',');
    return ['export', symbolList, 'as', namespace, 'from', `${namespace}/`]
        .join(' ');
  },
  validateImportPath(importPath: string) {
    // Paths produced for the 'Invalid' kernel are the only ones rejected.
    return importPath.indexOf('Invalid') === -1;
  }
};
/**
 * Specs for getCustomModuleString when only forward mode is requested.
 * Every kernel the provider accepts must be imported and registered once
 * per requested backend, and no gradient code may appear in the output.
 */
describe('getCustomModuleString forwardModeOnly=true', () => {
  const forwardModeOnly = true;

  // Builds a bundle config from the given lists and generates the module
  // strings with the mock provider. The cast is needed because the fake
  // backend names used here are not valid backends per the config type.
  function generate(
      kernels: string[], backends: string[], models: string[] = []) {
    const config = {kernels, backends, models, forwardModeOnly};
    return getCustomModuleString(
        config as CustomTFJSBundleConfig, mockImportProvider);
  }

  // Expects every snippet to occur somewhere in `text`.
  const expectPresent = (text: string, snippets: string[]) => {
    for (const snippet of snippets) {
      expect(text).toContain(snippet);
    }
  };

  // Expects none of the snippets to occur anywhere in `text`.
  const expectAbsent = (text: string, snippets: string[]) => {
    for (const snippet of snippets) {
      expect(text).not.toContain(snippet);
    }
  };

  it('one kernel, one backend', () => {
    // 'Invalid' is rejected by the mock provider's import-path validation.
    const {tfjs, core} = generate(['MathKrnl', 'Invalid'], ['FastBcknd']);
    expect(core).toContain('import CORE');
    expectPresent(tfjs, [
      'import CORE',
      'import BACKEND FastBcknd',
      'import KERNEL MathKrnl from BACKEND FastBcknd',
      'registerKernel(MathKrnl_FastBcknd)',
    ]);
    expectAbsent(tfjs, [
      'import KERNEL Invalid from BACKEND FastBcknd',
      'registerKernel(Invalid_FastBcknd)',
      'GRADIENT',
    ]);
  });

  it('one kernel, one backend, one model', () => {
    const {tfjs, core} =
        generate(['MathKrnl'], ['FastBcknd'], ['model1.json']);
    expect(core).toContain('import CORE');
    expectPresent(tfjs, [
      'import CORE',
      'import CONVERTER',
      'import BACKEND FastBcknd',
      'import KERNEL MathKrnl from BACKEND FastBcknd',
      'registerKernel(MathKrnl_FastBcknd)',
    ]);
    expectAbsent(tfjs, ['GRADIENT']);
  });

  it('one kernel, two backend', () => {
    const {tfjs} = generate(['MathKrnl'], ['FastBcknd', 'SlowBcknd']);
    expectPresent(tfjs, [
      'import CORE',
      'import BACKEND FastBcknd',
      'import KERNEL MathKrnl from BACKEND FastBcknd',
      'registerKernel(MathKrnl_FastBcknd)',
      'import BACKEND SlowBcknd',
      'import KERNEL MathKrnl from BACKEND SlowBcknd',
      'registerKernel(MathKrnl_SlowBcknd)',
    ]);
    expectAbsent(tfjs, ['GRADIENT']);
  });

  it('two kernels, one backend', () => {
    const {tfjs} = generate(['MathKrnl', 'MathKrn2'], ['FastBcknd']);
    expectPresent(tfjs, [
      'import CORE',
      'import BACKEND FastBcknd',
      'import KERNEL MathKrnl from BACKEND FastBcknd',
      'import KERNEL MathKrn2 from BACKEND FastBcknd',
      'registerKernel(MathKrnl_FastBcknd)',
      'registerKernel(MathKrn2_FastBcknd)',
    ]);
    expectAbsent(tfjs, ['GRADIENT']);
  });

  it('two kernels, two backends', () => {
    const {tfjs} =
        generate(['MathKrnl', 'MathKrn2'], ['FastBcknd', 'SlowBcknd']);
    expectPresent(tfjs, [
      'import CORE',
      'import BACKEND FastBcknd',
      'import KERNEL MathKrnl from BACKEND FastBcknd',
      'import KERNEL MathKrn2 from BACKEND FastBcknd',
      'registerKernel(MathKrnl_FastBcknd)',
      'registerKernel(MathKrn2_FastBcknd)',
      'import BACKEND SlowBcknd',
      'import KERNEL MathKrnl from BACKEND SlowBcknd',
      'import KERNEL MathKrn2 from BACKEND SlowBcknd',
      'registerKernel(MathKrnl_SlowBcknd)',
      'registerKernel(MathKrn2_SlowBcknd)',
    ]);
    expectAbsent(tfjs, ['GRADIENT']);
  });
});
/**
 * Specs for getCustomModuleString when gradients are also requested
 * (forwardModeOnly=false). Besides the kernel imports and registrations,
 * the output must import and register a gradient config for each kernel the
 * mock provider accepts. That gradient code must appear only once, however
 * many backends are requested.
 */
describe('getCustomModuleString forwardModeOnly=false', () => {
  const forwardModeOnly = false;
  it('one kernel, one backend', () => {
    // 'Invalid' maps to an import path the mock provider rejects, so neither
    // its kernel nor its gradient config may show up in the output.
    const config = {
      kernels: ['MathKrnl', 'Invalid'],
      backends: ['FastBcknd'],
      models: [] as string[],
      forwardModeOnly
    };
    const {tfjs} = getCustomModuleString(
        config as CustomTFJSBundleConfig, mockImportProvider);
    expect(tfjs).toContain('import CORE');
    expect(tfjs).toContain('import BACKEND FastBcknd');
    expect(tfjs).toContain('import KERNEL MathKrnl from BACKEND FastBcknd');
    expect(tfjs).toContain('registerKernel(MathKrnl_FastBcknd)');
    expect(tfjs).not.toContain('import KERNEL Invalid from BACKEND FastBcknd');
    expect(tfjs).not.toContain('registerKernel(Invalid_FastBcknd)');
    expect(tfjs).toContain('import GRADIENT MathKrnl');
    expect(tfjs).toContain('registerGradient(MathKrnl_GRAD_CONFIG)');
    expect(tfjs).not.toContain('import GRADIENT Invalid');
    // Gradient configs are registered with registerGradient (see the
    // positive check above), so that is the call that must be absent.
    // Checking registerKernel here would pass vacuously.
    expect(tfjs).not.toContain('registerGradient(Invalid_GRAD_CONFIG)');
  });
  it('one kernel, two backend', () => {
    const config = {
      kernels: ['MathKrnl'],
      backends: ['FastBcknd', 'SlowBcknd'],
      models: [] as string[],
      forwardModeOnly
    };
    const {tfjs} = getCustomModuleString(
        config as CustomTFJSBundleConfig, mockImportProvider);
    expect(tfjs).toContain('import GRADIENT MathKrnl');
    expect(tfjs).toContain('registerGradient(MathKrnl_GRAD_CONFIG)');
    // Gradient code does not depend on the backend, so a second backend
    // must not duplicate either the import or the registration.
    const gradIndex = tfjs.indexOf('GRADIENT');
    expect(tfjs.indexOf('GRADIENT', gradIndex + 1))
        .toBe(-1, `Gradient import appears twice in:\n ${tfjs}`);
    const gradRegistration = 'registerGradient(MathKrnl_GRAD_CONFIG)';
    const registrationIndex = tfjs.indexOf(gradRegistration);
    expect(tfjs.indexOf(gradRegistration, registrationIndex + 1))
        .toBe(-1, `Gradient registration appears twice in:\n ${tfjs}`);
  });
  it('two kernels, one backend', () => {
    const config = {
      kernels: ['MathKrnl', 'MathKrn2'],
      backends: ['FastBcknd'],
      models: [] as string[],
      forwardModeOnly
    };
    const {tfjs} = getCustomModuleString(
        config as CustomTFJSBundleConfig, mockImportProvider);
    expect(tfjs).toContain('import GRADIENT MathKrnl');
    expect(tfjs).toContain('registerGradient(MathKrnl_GRAD_CONFIG)');
    expect(tfjs).toContain('import GRADIENT MathKrn2');
    expect(tfjs).toContain('registerGradient(MathKrn2_GRAD_CONFIG)');
  });
  it('two kernels, two backends', () => {
    const config = {
      kernels: ['MathKrnl', 'MathKrn2'],
      backends: ['FastBcknd', 'SlowBcknd'],
      models: [] as string[],
      forwardModeOnly
    };
    const {tfjs} = getCustomModuleString(
        config as CustomTFJSBundleConfig, mockImportProvider);
    expect(tfjs).toContain('import GRADIENT MathKrnl');
    expect(tfjs).toContain('registerGradient(MathKrnl_GRAD_CONFIG)');
    expect(tfjs).toContain('import GRADIENT MathKrn2');
    expect(tfjs).toContain('registerGradient(MathKrn2_GRAD_CONFIG)');
  });
});
/**
 * Specs for getCustomConverterOpsModule. Plain op names should each produce
 * their own re-export line. Ops sharing a namespace prefix (e.g. 'image.')
 * should be grouped into one namespaced export line.
 */
describe('getCustomConverterOpsModule', () => {
  it('non namespaced ops', () => {
    const opNames = ['add', 'sub'];
    const moduleText = getCustomConverterOpsModule(opNames, mockImportProvider);
    const expectedLines = ['export * from add', 'export * from sub'];
    for (const line of expectedLines) {
      expect(moduleText).toContain(line);
    }
  });
  it('namespaced ops', () => {
    const imageOps = ['image.resizeBilinear', 'image.resizeNearestNeighbor'];
    const moduleText =
        getCustomConverterOpsModule(imageOps, mockImportProvider);
    const expectedExport =
        'export resizeBilinear,resizeNearestNeighbor as image from image/';
    expect(moduleText).toContain(expectedExport);
  });
});