UNPKG

react-native-executorch

Version:

An easy way to run AI models in React Native with ExecuTorch

468 lines (432 loc) 17 kB
import { ResourceSource, LabelEnum, PixelData, Frame, } from '../../types/common'; import { InstanceSegmentationModelSources, InstanceSegmentationConfig, InstanceSegmentationModelName, InstanceModelNameOf, NativeSegmentedInstance, SegmentedInstance, InstanceSegmentationOptions, } from '../../types/instanceSegmentation'; import { RnExecutorchErrorCode } from '../../errors/ErrorCodes'; import { RnExecutorchError } from '../../errors/errorUtils'; import { fetchModelPath, ResolveLabels as ResolveLabelsFor, VisionLabeledModule, } from './VisionLabeledModule'; import { CocoLabel, CocoLabelYolo, IMAGENET1K_MEAN, IMAGENET1K_STD, } from '../../constants/commonVision'; const YOLO_SEG_CONFIG = { preprocessorConfig: undefined, labelMap: CocoLabelYolo, availableInputSizes: [384, 512, 640] as const, defaultInputSize: 384, defaultConfidenceThreshold: 0.5, defaultIouThreshold: 0.5, postprocessorConfig: { applyNMS: false, }, } satisfies InstanceSegmentationConfig<typeof CocoLabelYolo>; const RF_DETR_NANO_SEG_CONFIG = { preprocessorConfig: { normMean: IMAGENET1K_MEAN, normStd: IMAGENET1K_STD }, labelMap: CocoLabel, availableInputSizes: undefined, defaultInputSize: undefined, //RFDetr exposes only one method named forward defaultConfidenceThreshold: 0.5, defaultIouThreshold: 0.5, postprocessorConfig: { applyNMS: true, }, } satisfies InstanceSegmentationConfig<typeof CocoLabel>; /** * Builds a reverse map from 0-based model class index to label key name, and * computes the minimum enum value (offset) so TS enum values can be converted * to 0-based model indices. * @param labelMap - The label enum to build the index map from. * @returns An object containing `indexToLabel` map and `minValue` offset. */ function buildClassIndexMap(labelMap: LabelEnum): { indexToLabel: Map<number, string>; minValue: number; } { const entries: [string, number][] = []; for (const [name, value] of Object.entries(labelMap)) { if (typeof value === 'number') entries.push([name, value]); } const minValue = Math.min(...entries.map(([, v]) => v)); const indexToLabel = new Map<number, string>(); for (const [name, value] of entries) { indexToLabel.set(value - minValue, name); } return { indexToLabel, minValue }; } const ModelConfigs = { 'yolo26n-seg': YOLO_SEG_CONFIG, 'yolo26s-seg': YOLO_SEG_CONFIG, 'yolo26m-seg': YOLO_SEG_CONFIG, 'yolo26l-seg': YOLO_SEG_CONFIG, 'yolo26x-seg': YOLO_SEG_CONFIG, 'rfdetr-nano-seg': RF_DETR_NANO_SEG_CONFIG, } as const satisfies Record< InstanceSegmentationModelName, | InstanceSegmentationConfig<typeof CocoLabel> | InstanceSegmentationConfig<typeof CocoLabelYolo> >; /** @internal */ type ModelConfigsType = typeof ModelConfigs; /** * Resolves the label map type for a given built-in model name. * @typeParam M - A built-in model name from {@link InstanceSegmentationModelName}. * @category Types */ export type InstanceSegmentationLabels< M extends InstanceSegmentationModelName, > = ResolveLabels<M>; /** * Resolves the label type: if `T` is a {@link InstanceSegmentationModelName}, looks up its labels * from the built-in config; otherwise uses `T` directly as a {@link LabelEnum}. * @internal */ type ResolveLabels<T extends InstanceSegmentationModelName | LabelEnum> = ResolveLabelsFor<T, ModelConfigsType>; /** * Generic instance segmentation module with type-safe label maps. * Use a model name (e.g. `'yolo26n-seg'`) as the generic parameter for pre-configured models, * or a custom label enum for custom configs. * * Supported models (download from HuggingFace): * - `yolo26n-seg`, `yolo26s-seg`, `yolo26m-seg`, `yolo26l-seg`, `yolo26x-seg` - YOLO models with COCO labels (80 classes) * - `rfdetr-nano-seg` - RF-DETR Nano model with COCO labels (80 classes) * @typeParam T - Either a pre-configured model name from {@link InstanceSegmentationModelName} * or a custom label map conforming to {@link LabelEnum}. * @category Typescript API * @example * ```ts * const segmentation = await InstanceSegmentationModule.fromModelName({ * modelName: 'yolo26n-seg', * modelSource: 'https://huggingface.co/.../yolo26n-seg.pte', * }); * * const results = await segmentation.forward('path/to/image.jpg', { * confidenceThreshold: 0.5, * iouThreshold: 0.45, * maxInstances: 20, * inputSize: 640, * }); * ``` */ export class InstanceSegmentationModule< T extends InstanceSegmentationModelName | LabelEnum, > extends VisionLabeledModule< SegmentedInstance<ResolveLabels<T>>[], ResolveLabels<T> > { private modelConfig: InstanceSegmentationConfig<LabelEnum>; private classIndexToLabel: Map<number, string>; private labelEnumOffset: number; private constructor( labelMap: ResolveLabels<T>, modelConfig: InstanceSegmentationConfig<LabelEnum>, nativeModule: unknown, classIndexToLabel: Map<number, string>, labelEnumOffset: number ) { super(labelMap, nativeModule); this.modelConfig = modelConfig; this.classIndexToLabel = classIndexToLabel; this.labelEnumOffset = labelEnumOffset; } /** * Creates an instance segmentation module for a pre-configured model. * The config object is discriminated by `modelName` — each model can require different fields. * @param config - A {@link InstanceSegmentationModelSources} object specifying which model to load and where to fetch it from. * @param onDownloadProgress - Optional callback to monitor download progress, receiving a value between 0 and 1. * @returns A Promise resolving to an `InstanceSegmentationModule` instance typed to the chosen model's label map. * @example * ```ts * const segmentation = await InstanceSegmentationModule.fromModelName({ * modelName: 'yolo26n-seg', * modelSource: 'https://huggingface.co/.../yolo26n-seg.pte', * }); * ``` */ static async fromModelName<C extends InstanceSegmentationModelSources>( config: C, onDownloadProgress: (progress: number) => void = () => {} ): Promise<InstanceSegmentationModule<InstanceModelNameOf<C>>> { const { modelName, modelSource } = config; const modelConfig = ModelConfigs[modelName as keyof typeof ModelConfigs]; const path = await fetchModelPath(modelSource, onDownloadProgress); if (typeof global.loadInstanceSegmentation !== 'function') { throw new RnExecutorchError( RnExecutorchErrorCode.ModuleNotLoaded, `global.loadInstanceSegmentation is not available` ); } const { indexToLabel, minValue } = buildClassIndexMap(modelConfig.labelMap); const nativeModule = await global.loadInstanceSegmentation( path, modelConfig.preprocessorConfig?.normMean || [], modelConfig.preprocessorConfig?.normStd || [], modelConfig.postprocessorConfig?.applyNMS ?? true ); return new InstanceSegmentationModule<InstanceModelNameOf<C>>( modelConfig.labelMap as ResolveLabels<InstanceModelNameOf<C>>, modelConfig, nativeModule, indexToLabel, minValue ); } /** * Creates an instance segmentation module with a user-provided label map and custom config. * Use this when working with a custom-exported segmentation model that is not one of the pre-configured models. * @param modelSource - A fetchable resource pointing to the model binary. * @param config - A {@link InstanceSegmentationConfig} object with the label map and optional preprocessing parameters. * @param onDownloadProgress - Optional callback to monitor download progress, receiving a value between 0 and 1. * @returns A Promise resolving to an `InstanceSegmentationModule` instance typed to the provided label map. * @example * ```ts * const MyLabels = { PERSON: 0, CAR: 1 } as const; * const segmentation = await InstanceSegmentationModule.fromCustomModel( * 'https://huggingface.co/.../custom_model.pte', * { * labelMap: MyLabels, * availableInputSizes: [640], * defaultInputSize: 640, * defaultConfidenceThreshold: 0.5, * defaultIouThreshold: 0.45, * postprocessorConfig: { applyNMS: true }, * }, * ); * ``` */ static async fromCustomModel<L extends LabelEnum>( modelSource: ResourceSource, config: InstanceSegmentationConfig<L>, onDownloadProgress: (progress: number) => void = () => {} ): Promise<InstanceSegmentationModule<L>> { const path = await fetchModelPath(modelSource, onDownloadProgress); if (typeof global.loadInstanceSegmentation !== 'function') { throw new RnExecutorchError( RnExecutorchErrorCode.ModuleNotLoaded, `global.loadInstanceSegmentation is not available` ); } const { indexToLabel, minValue } = buildClassIndexMap(config.labelMap); const nativeModule = await global.loadInstanceSegmentation( path, config.preprocessorConfig?.normMean || [], config.preprocessorConfig?.normStd || [], config.postprocessorConfig?.applyNMS ?? true ); return new InstanceSegmentationModule<L>( config.labelMap as ResolveLabels<L>, config, nativeModule, indexToLabel, minValue ); } /** * Returns the available input sizes for this model, or undefined if the model accepts any size. * @returns An array of available input sizes, or undefined if not constrained. * @example * ```ts * const sizes = segmentation.getAvailableInputSizes(); * console.log(sizes); // [384, 512, 640] for YOLO models, or undefined for RF-DETR * ``` */ getAvailableInputSizes(): readonly number[] | undefined { return this.modelConfig.availableInputSizes; } /** * Override runOnFrame to add label mapping for VisionCamera integration. * The parent's runOnFrame returns raw native results with class indices; * this override maps them to label strings and provides an options-based API. * @returns A worklet function for VisionCamera frame processing. * @throws {RnExecutorchError} If the underlying native worklet is unavailable (should not occur on a loaded module). */ override get runOnFrame(): ( frame: Frame, isFrontCamera: boolean, options?: InstanceSegmentationOptions<ResolveLabels<T>> ) => SegmentedInstance<ResolveLabels<T>>[] { const baseRunOnFrame = super.runOnFrame; if (!baseRunOnFrame) { throw new RnExecutorchError( RnExecutorchErrorCode.ModuleNotLoaded, 'Model is not loaded. Ensure the model has been loaded before using runOnFrame.' ); } // Convert Map to plain object for worklet serialization const labelLookup: Record<number, string> = {}; this.classIndexToLabel.forEach((label, index) => { labelLookup[index] = label; }); // Create reverse map (label → enum value) for classesOfInterest lookup const labelMap: Record<string, number> = {}; for (const [name, value] of Object.entries(this.labelMap)) { if (typeof value === 'number') { labelMap[name] = value; } } const labelEnumOffset = this.labelEnumOffset; const defaultConfidenceThreshold = this.modelConfig.defaultConfidenceThreshold ?? 0.5; const defaultIouThreshold = this.modelConfig.defaultIouThreshold ?? 0.5; const defaultInputSize = this.modelConfig.defaultInputSize; return ( frame: Frame, isFrontCamera: boolean, options?: InstanceSegmentationOptions<ResolveLabels<T>> ): SegmentedInstance<ResolveLabels<T>>[] => { 'worklet'; const confidenceThreshold = options?.confidenceThreshold ?? defaultConfidenceThreshold; const iouThreshold = options?.iouThreshold ?? defaultIouThreshold; const maxInstances = options?.maxInstances ?? 100; const returnMaskAtOriginalResolution = options?.returnMaskAtOriginalResolution ?? true; const inputSize = options?.inputSize ?? defaultInputSize; const methodName = inputSize !== undefined ? `forward_${inputSize}` : 'forward'; const classIndices = options?.classesOfInterest ? options.classesOfInterest.map((label) => { const labelStr = String(label); const enumValue = labelMap[labelStr]; // Don't normalize - send raw enum values to match model output return typeof enumValue === 'number' ? enumValue : -1; }) : []; const nativeResults = baseRunOnFrame( frame, isFrontCamera, confidenceThreshold, iouThreshold, maxInstances, classIndices, returnMaskAtOriginalResolution, methodName ); return nativeResults.map((inst: any) => ({ bbox: inst.bbox, mask: inst.mask, maskWidth: inst.maskWidth, maskHeight: inst.maskHeight, label: (labelLookup[inst.classIndex - labelEnumOffset] ?? String(inst.classIndex)) as keyof ResolveLabels<T>, score: inst.score, })); }; } /** * Executes the model's forward pass to perform instance segmentation on the provided image. * * Supports two input types: * 1. **String path/URI**: File path, URL, or Base64-encoded string * 2. **PixelData**: Raw pixel data from image libraries (e.g., NitroImage) * @param input - Image source (string path or PixelData object) * @param options - Optional configuration for the segmentation process. Includes `confidenceThreshold`, `iouThreshold`, `maxInstances`, `classesOfInterest`, `returnMaskAtOriginalResolution`, and `inputSize`. * @returns A Promise resolving to an array of {@link SegmentedInstance} objects with `bbox`, `mask`, `maskWidth`, `maskHeight`, `label`, `score`. * @throws {RnExecutorchError} If the model is not loaded or if an invalid `inputSize` is provided. * @example * ```ts * const results = await segmentation.forward('path/to/image.jpg', { * confidenceThreshold: 0.6, * iouThreshold: 0.5, * maxInstances: 10, * inputSize: 640, * classesOfInterest: ['PERSON', 'CAR'], * returnMaskAtOriginalResolution: true, * }); * * results.forEach((inst) => { * console.log(`${inst.label}: ${(inst.score * 100).toFixed(1)}%`); * }); * ``` */ async forward( input: string | PixelData, options?: InstanceSegmentationOptions<ResolveLabels<T>> ): Promise<SegmentedInstance<ResolveLabels<T>>[]> { if (this.nativeModule == null) { throw new RnExecutorchError( RnExecutorchErrorCode.ModuleNotLoaded, 'The model is currently not loaded.' ); } const confidenceThreshold = options?.confidenceThreshold ?? this.modelConfig.defaultConfidenceThreshold ?? 0.5; const iouThreshold = options?.iouThreshold ?? this.modelConfig.defaultIouThreshold ?? 0.5; const maxInstances = options?.maxInstances ?? 100; const returnMaskAtOriginalResolution = options?.returnMaskAtOriginalResolution ?? true; const inputSize = options?.inputSize ?? this.modelConfig.defaultInputSize; if ( this.modelConfig.availableInputSizes && inputSize !== undefined && !this.modelConfig.availableInputSizes.includes( inputSize as (typeof this.modelConfig.availableInputSizes)[number] ) ) { throw new RnExecutorchError( RnExecutorchErrorCode.InvalidArgument, `Invalid inputSize: ${inputSize}. Available sizes: ${this.modelConfig.availableInputSizes.join(', ')}` ); } const methodName = inputSize !== undefined ? `forward_${inputSize}` : 'forward'; const classIndices = options?.classesOfInterest ? options.classesOfInterest.map((label) => { const labelStr = String(label); const enumValue = this.labelMap[labelStr as keyof ResolveLabels<T>]; // Don't normalize - send raw enum values to match model output return typeof enumValue === 'number' ? enumValue : -1; }) : []; const nativeResult: NativeSegmentedInstance[] = typeof input === 'string' ? await this.nativeModule.generateFromString( input, confidenceThreshold, iouThreshold, maxInstances, classIndices, returnMaskAtOriginalResolution, methodName ) : await this.nativeModule.generateFromPixels( input, confidenceThreshold, iouThreshold, maxInstances, classIndices, returnMaskAtOriginalResolution, methodName ); return nativeResult.map((inst) => ({ bbox: inst.bbox, mask: inst.mask, maskWidth: inst.maskWidth, maskHeight: inst.maskHeight, label: (this.classIndexToLabel.get( inst.classIndex - this.labelEnumOffset ) ?? String(inst.classIndex)) as keyof ResolveLabels<T>, score: inst.score, })); } }