UNPKG

@tensorflow/tfjs

Version:

An open-source machine learning framework.

277 lines (229 loc) 9.83 kB
/**
 * @license
 * Copyright 2020 Google Inc. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */

import {getCustomConverterOpsModule, getCustomModuleString} from './custom_module';
import {CustomTFJSBundleConfig, ImportProvider} from './types';

/**
 * Stub ImportProvider that returns short, recognizable placeholder strings so
 * the tests can assert on what ends up in the generated module text.
 *
 * A kernel (or gradient) named 'Invalid' is given an import path containing
 * the substring 'Invalid', which `validateImportPath` below rejects. The tests
 * use this to check that rejected imports do not show up in the output.
 */
const mockImportProvider: ImportProvider = {
  importCoreStr: () => 'import CORE',
  importConverterStr: () => 'import CONVERTER',
  importBackendStr: (name: string) => `import BACKEND ${name}`,
  importKernelStr: (kernelName: string, backend: string) => {
    const pathPrefix = kernelName === 'Invalid' ? 'BACKEND_Invalid' : 'BACKEND';
    const importPath = `${pathPrefix} ${backend}`;
    return {
      importPath,
      importStatement: `import KERNEL ${kernelName} from ${importPath}`,
      kernelConfigId: `${kernelName}_${backend}`
    };
  },
  importGradientConfigStr: (kernel: string) => {
    const importPath = kernel === 'Invalid' ? 'BACKEND_Invalid' : 'BACKEND';
    return {
      importPath,
      importStatement: `import GRADIENT ${kernel} from ${importPath}`,
      gradConfigId: `${kernel}_GRAD_CONFIG`,
    };
  },
  importOpForConverterStr: (opSymbol: string) => {
    return `export * from ${opSymbol}`;
  },
  importNamespacedOpsForConverterStr: (
      namespace: string, opSymbols: string[]) => {
    return `export ${opSymbols.join(',')} as ${namespace} from ${namespace}/`;
  },
  validateImportPath: (importPath: string) => {
    return !importPath.includes('Invalid');
  }
};

// With forwardModeOnly=true every test asserts that no 'GRADIENT' text is
// present in the generated tfjs module.
describe('getCustomModuleString forwardModeOnly=true', () => {
  const forwardModeOnly = true;

  it('one kernel, one backend', () => {
    const config = {
      kernels: ['MathKrnl', 'Invalid'],
      backends: ['FastBcknd'],
      models: [] as string[],
      forwardModeOnly
    };

    const {tfjs, core} = getCustomModuleString(
        // cast because FastBcknd is not a valid backend per the type
        config as CustomTFJSBundleConfig, mockImportProvider);

    expect(core).toContain('import CORE');
    expect(tfjs).toContain('import CORE');

    expect(tfjs).toContain('import BACKEND FastBcknd');
    expect(tfjs).toContain('import KERNEL MathKrnl from BACKEND FastBcknd');
    expect(tfjs).toContain('registerKernel(MathKrnl_FastBcknd)');

    // The mock emits 'import KERNEL Invalid from BACKEND_Invalid FastBcknd'
    // for the invalid kernel, so match on the statement prefix; matching the
    // full valid-path form could never fail.
    expect(tfjs).not.toContain('import KERNEL Invalid');
    expect(tfjs).not.toContain('registerKernel(Invalid_FastBcknd)');

    expect(tfjs).not.toContain('GRADIENT');
  });

  it('one kernel, one backend, one model', () => {
    const config = {
      kernels: ['MathKrnl'],
      backends: ['FastBcknd'],
      models: ['model1.json'],
      forwardModeOnly
    };

    const {tfjs, core} = getCustomModuleString(
        // cast because FastBcknd is not a valid backend per the type
        config as CustomTFJSBundleConfig, mockImportProvider);

    expect(core).toContain('import CORE');
    expect(tfjs).toContain('import CORE');
    // A non-empty models list is expected to pull in the converter import.
    expect(tfjs).toContain('import CONVERTER');

    expect(tfjs).toContain('import BACKEND FastBcknd');
    expect(tfjs).toContain('import KERNEL MathKrnl from BACKEND FastBcknd');
    expect(tfjs).toContain('registerKernel(MathKrnl_FastBcknd)');

    expect(tfjs).not.toContain('GRADIENT');
  });

  it('one kernel, two backend', () => {
    const config = {
      kernels: ['MathKrnl'],
      backends: ['FastBcknd', 'SlowBcknd'],
      models: [] as string[],
      forwardModeOnly
    };

    const {tfjs} = getCustomModuleString(
        // cast because the backends are not truly valid backend per the type
        config as CustomTFJSBundleConfig, mockImportProvider);

    expect(tfjs).toContain('import CORE');

    expect(tfjs).toContain('import BACKEND FastBcknd');
    expect(tfjs).toContain('import KERNEL MathKrnl from BACKEND FastBcknd');
    expect(tfjs).toContain('registerKernel(MathKrnl_FastBcknd)');

    expect(tfjs).toContain('import BACKEND SlowBcknd');
    expect(tfjs).toContain('import KERNEL MathKrnl from BACKEND SlowBcknd');
    expect(tfjs).toContain('registerKernel(MathKrnl_SlowBcknd)');

    expect(tfjs).not.toContain('GRADIENT');
  });

  it('two kernels, one backend', () => {
    const config = {
      kernels: ['MathKrnl', 'MathKrn2'],
      backends: ['FastBcknd'],
      models: [] as string[],
      forwardModeOnly
    };

    const {tfjs} = getCustomModuleString(
        config as CustomTFJSBundleConfig, mockImportProvider);

    expect(tfjs).toContain('import CORE');

    expect(tfjs).toContain('import BACKEND FastBcknd');
    expect(tfjs).toContain('import KERNEL MathKrnl from BACKEND FastBcknd');
    expect(tfjs).toContain('import KERNEL MathKrn2 from BACKEND FastBcknd');
    expect(tfjs).toContain('registerKernel(MathKrnl_FastBcknd)');
    expect(tfjs).toContain('registerKernel(MathKrn2_FastBcknd)');

    expect(tfjs).not.toContain('GRADIENT');
  });

  it('two kernels, two backends', () => {
    const config = {
      kernels: ['MathKrnl', 'MathKrn2'],
      backends: ['FastBcknd', 'SlowBcknd'],
      models: [] as string[],
      forwardModeOnly
    };

    const {tfjs} = getCustomModuleString(
        config as CustomTFJSBundleConfig, mockImportProvider);

    expect(tfjs).toContain('import CORE');

    expect(tfjs).toContain('import BACKEND FastBcknd');
    expect(tfjs).toContain('import KERNEL MathKrnl from BACKEND FastBcknd');
    expect(tfjs).toContain('import KERNEL MathKrn2 from BACKEND FastBcknd');
    expect(tfjs).toContain('registerKernel(MathKrnl_FastBcknd)');
    expect(tfjs).toContain('registerKernel(MathKrn2_FastBcknd)');

    expect(tfjs).toContain('import BACKEND SlowBcknd');
    expect(tfjs).toContain('import KERNEL MathKrnl from BACKEND SlowBcknd');
    expect(tfjs).toContain('import KERNEL MathKrn2 from BACKEND SlowBcknd');
    expect(tfjs).toContain('registerKernel(MathKrnl_SlowBcknd)');
    expect(tfjs).toContain('registerKernel(MathKrn2_SlowBcknd)');

    expect(tfjs).not.toContain('GRADIENT');
  });
});

// With forwardModeOnly=false the tests additionally expect a gradient import
// and a registerGradient(...) call for each valid kernel.
describe('getCustomModuleString forwardModeOnly=false', () => {
  const forwardModeOnly = false;

  it('one kernel, one backend', () => {
    const config = {
      kernels: ['MathKrnl', 'Invalid'],
      backends: ['FastBcknd'],
      models: [] as string[],
      forwardModeOnly
    };

    const {tfjs} = getCustomModuleString(
        config as CustomTFJSBundleConfig, mockImportProvider);

    expect(tfjs).toContain('import CORE');

    expect(tfjs).toContain('import BACKEND FastBcknd');
    expect(tfjs).toContain('import KERNEL MathKrnl from BACKEND FastBcknd');
    expect(tfjs).toContain('registerKernel(MathKrnl_FastBcknd)');

    // Match on the statement prefix: the mock's invalid kernel import reads
    // '... from BACKEND_Invalid FastBcknd', not '... from BACKEND FastBcknd'.
    expect(tfjs).not.toContain('import KERNEL Invalid');
    expect(tfjs).not.toContain('registerKernel(Invalid_FastBcknd)');

    expect(tfjs).toContain('import GRADIENT MathKrnl');
    expect(tfjs).toContain('registerGradient(MathKrnl_GRAD_CONFIG)');

    expect(tfjs).not.toContain('import GRADIENT Invalid');
    // Gradient configs are registered through registerGradient (see the
    // positive assertion above), so that is the call that must be absent.
    expect(tfjs).not.toContain('registerGradient(Invalid_GRAD_CONFIG)');
  });

  it('one kernel, two backend', () => {
    const config = {
      kernels: ['MathKrnl'],
      backends: ['FastBcknd', 'SlowBcknd'],
      models: [] as string[],
      forwardModeOnly
    };

    const {tfjs} = getCustomModuleString(
        config as CustomTFJSBundleConfig, mockImportProvider);

    expect(tfjs).toContain('import GRADIENT MathKrnl');
    expect(tfjs).toContain('registerGradient(MathKrnl_GRAD_CONFIG)');

    // The gradient import must appear once even though there are two
    // backends: search for a second 'GRADIENT' after the first occurrence.
    const gradIndex = tfjs.indexOf('GRADIENT');
    expect(tfjs.indexOf('GRADIENT', gradIndex + 1))
        .toBe(-1, `Gradient import appears twice in:\n ${tfjs}`);
  });

  it('two kernels, one backend', () => {
    const config = {
      kernels: ['MathKrnl', 'MathKrn2'],
      backends: ['FastBcknd'],
      models: [] as string[],
      forwardModeOnly
    };

    const {tfjs} = getCustomModuleString(
        config as CustomTFJSBundleConfig, mockImportProvider);

    expect(tfjs).toContain('import GRADIENT MathKrnl');
    expect(tfjs).toContain('registerGradient(MathKrnl_GRAD_CONFIG)');

    expect(tfjs).toContain('import GRADIENT MathKrn2');
    expect(tfjs).toContain('registerGradient(MathKrn2_GRAD_CONFIG)');
  });

  it('two kernels, two backends', () => {
    const config = {
      kernels: ['MathKrnl', 'MathKrn2'],
      backends: ['FastBcknd', 'SlowBcknd'],
      models: [] as string[],
      forwardModeOnly
    };

    const {tfjs} = getCustomModuleString(
        config as CustomTFJSBundleConfig, mockImportProvider);

    expect(tfjs).toContain('import GRADIENT MathKrnl');
    expect(tfjs).toContain('registerGradient(MathKrnl_GRAD_CONFIG)');

    expect(tfjs).toContain('import GRADIENT MathKrn2');
    expect(tfjs).toContain('registerGradient(MathKrn2_GRAD_CONFIG)');
  });
});

// Checks the converter ops module text for plain op names and for
// dot-qualified ('namespace.op') names.
describe('getCustomConverterOpsModule', () => {
  it('non namespaced ops', () => {
    const result =
        getCustomConverterOpsModule(['add', 'sub'], mockImportProvider);

    expect(result).toContain('export * from add');
    expect(result).toContain('export * from sub');
  });

  it('namespaced ops', () => {
    const result = getCustomConverterOpsModule(
        ['image.resizeBilinear', 'image.resizeNearestNeighbor'],
        mockImportProvider);

    // Both ops share the 'image' namespace and are expected to be emitted as
    // a single grouped export.
    expect(result).toContain(
        'export resizeBilinear,resizeNearestNeighbor as image from image/');
  });
});