UNPKG

react-native-executorch

Version:

An easy way to run AI models in React Native with ExecuTorch

68 lines (64 loc) 1.93 kB
import { PixelData } from '../..';
import {
  SemanticSegmentationModule,
  SegmentationLabels,
} from '../../modules/computer_vision/SemanticSegmentationModule';
import {
  SemanticSegmentationProps,
  SemanticSegmentationType,
  ModelNameOf,
  SemanticSegmentationModelSources,
} from '../../types/semanticSegmentation';
import { useModuleFactory } from '../useModuleFactory';

/**
 * Hook that instantiates a built-in Semantic Segmentation model and exposes a
 * label-typed `forward` function for running inference with it.
 *
 * The module is built by `SemanticSegmentationModule.fromModelName` through
 * `useModuleFactory`, with `model.modelName` and `model.modelSource` supplied as
 * the factory's dependency list. `preventLoad` (default `false`) is forwarded
 * to `useModuleFactory` unchanged.
 *
 * @typeParam C - A {@link SemanticSegmentationModelSources} config selecting the built-in model; it determines which class labels `forward` accepts.
 * @param props - Object holding the `model` config and an optional `preventLoad` flag.
 * @returns The model state (`error`, `isReady`, `isGenerating`, `downloadProgress`), a `forward` function typed by the model's labels, and `runOnFrame` as provided by `useModuleFactory`.
 * @example
 * ```ts
 * const { isReady, forward } = useSemanticSegmentation({
 *   model: { modelName: 'deeplab-v3-resnet50', modelSource: DEEPLAB_V3_RESNET50 },
 * });
 * ```
 * @category Hooks
 */
export const useSemanticSegmentation = <
  C extends SemanticSegmentationModelSources,
>({
  model,
  preventLoad = false,
}: SemanticSegmentationProps<C>): SemanticSegmentationType<
  SegmentationLabels<ModelNameOf<C>>
> => {
  // Label map of the selected model; its keys are the class names `forward` accepts.
  type Labels = SegmentationLabels<ModelNameOf<C>>;

  const {
    runForward,
    runOnFrame,
    error,
    isReady,
    isGenerating,
    downloadProgress,
  } = useModuleFactory({
    // Arrow wrapper keeps `fromModelName` invoked on the class itself.
    factory: (config, onProgress) =>
      SemanticSegmentationModule.fromModelName(config, onProgress),
    config: model,
    // Kept inline so the array literal is contextually typed by `useModuleFactory`.
    deps: [model.modelName, model.modelSource],
    preventLoad,
  });

  /**
   * Runs segmentation on one image through the loaded module.
   *
   * @param imageSource - Image as a string source or raw `PixelData`.
   * @param classesOfInterest - Label keys to request; defaults to an empty list.
   * @param resizeToInput - Passed straight to the module's `forward`; defaults to `true`.
   * @returns Whatever `runForward` yields for the module's `forward` call.
   */
  function forward<K extends keyof Labels>(
    imageSource: string | PixelData,
    classesOfInterest: K[] = [],
    resizeToInput = true
  ) {
    return runForward((instance) =>
      instance.forward(imageSource, classesOfInterest, resizeToInput)
    );
  }

  const hookResult = {
    error,
    isReady,
    isGenerating,
    downloadProgress,
    forward,
    runOnFrame,
  };
  return hookResult;
};