UNPKG

react-native-executorch

Version:

An easy way to run AI models in React Native with ExecuTorch

235 lines (219 loc) 11 kB
"use strict";

import { RnExecutorchErrorCode } from '../../errors/ErrorCodes';
import { RnExecutorchError } from '../../errors/errorUtils';
import { CocoLabel, CocoLabelYolo, IMAGENET1K_MEAN, IMAGENET1K_STD } from '../../constants/commonVision';
import { fetchModelPath, VisionLabeledModule } from './VisionLabeledModule';

/**
 * Configuration shared (by reference) by every YOLO26 preset: YOLO label map, no extra
 * normalization, and three selectable input sizes that map to the `forward_384`,
 * `forward_512` and `forward_640` methods.
 */
const YOLO_DETECTION_CONFIG = {
  labelMap: CocoLabelYolo,
  preprocessorConfig: undefined,
  availableInputSizes: [384, 512, 640],
  defaultInputSize: 384,
  defaultDetectionThreshold: 0.5,
  defaultIouThreshold: 0.5
};

/**
 * Per-model defaults for the built-in presets, keyed by the public model name.
 * `availableInputSizes: undefined` means the model is not constrained to a fixed set of
 * input sizes and is invoked through its plain `forward` method.
 */
const ModelConfigs = {
  'ssdlite-320-mobilenet-v3-large': {
    labelMap: CocoLabel,
    preprocessorConfig: undefined,
    availableInputSizes: undefined,
    defaultInputSize: undefined,
    defaultDetectionThreshold: 0.7,
    defaultIouThreshold: 0.55
  },
  'rf-detr-nano': {
    labelMap: CocoLabel,
    preprocessorConfig: {
      normMean: IMAGENET1K_MEAN,
      normStd: IMAGENET1K_STD
    },
    availableInputSizes: undefined,
    defaultInputSize: undefined,
    defaultDetectionThreshold: 0.7,
    defaultIouThreshold: 0.55
  },
  'yolo26n': YOLO_DETECTION_CONFIG,
  'yolo26s': YOLO_DETECTION_CONFIG,
  'yolo26m': YOLO_DETECTION_CONFIG,
  'yolo26l': YOLO_DETECTION_CONFIG,
  'yolo26x': YOLO_DETECTION_CONFIG
};

/**
 * Resolves the {@link LabelEnum} for a given built-in object detection model name.
 * @typeParam M - A built-in model name from {@link ObjectDetectionModelName}.
 * @category Types
 */

/** @internal */

/**
 * Looks up the preset configuration for a built-in model name.
 * @param modelName - Name of a built-in preset (a key of `ModelConfigs`).
 * @returns The preset's configuration object.
 * @throws {RnExecutorchError} `InvalidArgument` if `modelName` is not a known preset.
 */
function resolveModelConfig(modelName) {
  // Own-property check so inherited keys such as 'toString' are not mistaken for presets.
  if (!Object.prototype.hasOwnProperty.call(ModelConfigs, modelName)) {
    throw new RnExecutorchError(
      RnExecutorchErrorCode.InvalidArgument,
      `Unknown object detection model: ${String(modelName)}. Available models: ${Object.keys(ModelConfigs).join(', ')}`
    );
  }
  return ModelConfigs[modelName];
}

/**
 * Builds the index-ordered list of class names handed to the native loader.
 * Only entries whose value is a number are used (the value is the class index);
 * string-valued entries are skipped. Indices with no label are filled with ''
 * so the resulting array is dense.
 * @param labelMap - Label enum object mapping label name → class index.
 * @returns Array where element `i` is the name of class `i` (or '').
 */
function buildLabelNames(labelMap) {
  const names = [];
  for (const [name, value] of Object.entries(labelMap)) {
    if (typeof value === 'number') names[value] = name;
  }
  // Array.from visits holes as undefined, which become ''.
  return Array.from(names, (name) => name ?? '');
}

/**
 * Downloads (if needed) the model binary and loads it through the native
 * `global.loadObjectDetection` binding. Shared by both static factories.
 * @param modelSource - Fetchable resource pointing to the model binary.
 * @param labelMap - Label enum object (name → class index).
 * @param preprocessorConfig - Optional `{ normMean, normStd }`; missing values become `[]`.
 * @param onDownloadProgress - Download progress callback forwarded to `fetchModelPath`.
 * @returns A Promise resolving to the native module handle.
 */
async function loadNativeDetector(modelSource, labelMap, preprocessorConfig, onDownloadProgress) {
  const normMean = preprocessorConfig?.normMean ?? [];
  const normStd = preprocessorConfig?.normStd ?? [];
  const labelNames = buildLabelNames(labelMap);
  const modelPath = await fetchModelPath(modelSource, onDownloadProgress);
  return global.loadObjectDetection(modelPath, normMean, normStd, labelNames);
}

/**
 * Generic object detection module with type-safe label maps.
 * @typeParam T - Either a built-in model name (e.g. `'ssdlite-320-mobilenet-v3-large'`)
 * or a custom {@link LabelEnum} label map.
 * @category Typescript API
 */
export class ObjectDetectionModule extends VisionLabeledModule {
  /**
   * @param labelMap - Label enum object (name → class index).
   * @param modelConfig - Model configuration; supplies default thresholds and input sizes.
   * @param nativeModule - Handle returned by the native loader.
   */
  constructor(labelMap, modelConfig, nativeModule) {
    super(labelMap, nativeModule);
    this.modelConfig = modelConfig;
  }

  /**
   * Creates an object detection instance for a built-in model.
   * @param namedSources - A {@link ObjectDetectionModelSources} object specifying which model to load and where to fetch it from.
   * @param onDownloadProgress - Optional callback to monitor download progress, receiving a value between 0 and 1.
   * @returns A Promise resolving to an `ObjectDetectionModule` instance typed to the chosen model's label map.
   * @throws {RnExecutorchError} `InvalidArgument` if `namedSources.modelName` is not a built-in preset.
   */
  static async fromModelName(namedSources, onDownloadProgress = () => {}) {
    const { modelName, modelSource } = namedSources;
    const modelConfig = resolveModelConfig(modelName);
    const nativeModule = await loadNativeDetector(
      modelSource,
      modelConfig.labelMap,
      modelConfig.preprocessorConfig,
      onDownloadProgress
    );
    return new ObjectDetectionModule(modelConfig.labelMap, modelConfig, nativeModule);
  }

  /**
   * Returns the available input sizes for this model, or undefined if the model accepts any size.
   * @returns An array of available input sizes, or undefined if not constrained.
   * @example
   * ```typescript
   * const sizes = model.getAvailableInputSizes(); // [384, 512, 640] for YOLO models, or undefined for RF-DETR
   * ```
   */
  getAvailableInputSizes() {
    return this.modelConfig.availableInputSizes;
  }

  /**
   * Override runOnFrame to provide an options-based API for VisionCamera integration.
   *
   * The returned worklet accepts `(frame, isFrontCamera, options)`, where `options` may set
   * `detectionThreshold`, `iouThreshold`, `inputSize` and `classesOfInterest`. Missing values
   * fall back to the model config, then to 0.7 / 0.55 for the thresholds.
   * @returns A worklet function for frame processing.
   * @throws {RnExecutorchError} If the underlying native worklet is unavailable (should not occur on a loaded module).
   */
  get runOnFrame() {
    const baseRunOnFrame = super.runOnFrame;
    if (!baseRunOnFrame) {
      throw new RnExecutorchError(
        RnExecutorchErrorCode.ModuleNotLoaded,
        'Model is not loaded. Ensure the model has been loaded before using runOnFrame.'
      );
    }

    // Plain-object copy of the numeric label entries (name → class index) used to
    // resolve classesOfInterest inside the worklet; presumably copied so the worklet
    // closure captures a simple object rather than `this` — confirm.
    const labelMap = {};
    for (const [name, value] of Object.entries(this.labelMap)) {
      if (typeof value === 'number') {
        labelMap[name] = value;
      }
    }

    const defaultDetectionThreshold = this.modelConfig.defaultDetectionThreshold ?? 0.7;
    const defaultIouThreshold = this.modelConfig.defaultIouThreshold ?? 0.55;
    const defaultInputSize = this.modelConfig.defaultInputSize;
    const availableInputSizes = this.modelConfig.availableInputSizes;

    return (frame, isFrontCamera, options) => {
      'worklet';

      const detectionThreshold = options?.detectionThreshold ?? defaultDetectionThreshold;
      const iouThreshold = options?.iouThreshold ?? defaultIouThreshold;
      const inputSize = options?.inputSize ?? defaultInputSize;

      // NOTE(review): a plain Error is thrown here (unlike forward()), presumably because
      // RnExecutorchError is not available inside the worklet runtime — confirm.
      if (availableInputSizes && inputSize !== undefined && !availableInputSizes.includes(inputSize)) {
        throw new Error(`Invalid inputSize: ${inputSize}. Available sizes: ${availableInputSizes.join(', ')}`);
      }

      const methodName = inputSize !== undefined ? `forward_${inputSize}` : 'forward';

      // Unknown labels are passed through as -1.
      const classIndices = options?.classesOfInterest
        ? options.classesOfInterest.map((label) => {
            const labelStr = String(label);
            const enumValue = labelMap[labelStr];
            return typeof enumValue === 'number' ? enumValue : -1;
          })
        : [];

      return baseRunOnFrame(frame, isFrontCamera, detectionThreshold, iouThreshold, classIndices, methodName);
    };
  }

  /**
   * Executes the model's forward pass to detect objects within the provided image.
   *
   * Supports two input types:
   * 1. **String path/URI**: File path, URL, or Base64-encoded string
   * 2. **PixelData**: Raw pixel data from image libraries (e.g., NitroImage)
   * @param input - A string image source (file path, URI, or Base64) or a {@link PixelData} object.
   * @param options - Optional configuration for detection inference. Includes `detectionThreshold`, `iouThreshold`, `inputSize`, and `classesOfInterest`.
   * @returns A Promise resolving to an array of {@link Detection} objects.
   * @throws {RnExecutorchError} If the model is not loaded or if an invalid `inputSize` is provided.
   * @example
   * ```typescript
   * const detections = await model.forward('path/to/image.jpg', {
   *   detectionThreshold: 0.7,
   *   inputSize: 640, // For YOLO models
   *   classesOfInterest: ['PERSON', 'CAR'],
   * });
   * ```
   */
  async forward(input, options) {
    if (this.nativeModule == null) {
      throw new RnExecutorchError(
        RnExecutorchErrorCode.ModuleNotLoaded,
        'The model is currently not loaded. Please load the model before calling forward().'
      );
    }

    // Option value → model-config default → hard-coded fallback.
    const detectionThreshold = options?.detectionThreshold ?? this.modelConfig.defaultDetectionThreshold ?? 0.7;
    const iouThreshold = options?.iouThreshold ?? this.modelConfig.defaultIouThreshold ?? 0.55;
    const inputSize = options?.inputSize ?? this.modelConfig.defaultInputSize;

    const { availableInputSizes } = this.modelConfig;
    if (availableInputSizes && inputSize !== undefined && !availableInputSizes.includes(inputSize)) {
      throw new RnExecutorchError(
        RnExecutorchErrorCode.InvalidArgument,
        `Invalid inputSize: ${inputSize}. Available sizes: ${availableInputSizes.join(', ')}`
      );
    }

    // forward_384 / forward_512 / forward_640 for size-constrained models, plain forward otherwise.
    const methodName = inputSize !== undefined ? `forward_${inputSize}` : 'forward';

    // Label names → class indices; unknown labels are passed through as -1.
    const classIndices = options?.classesOfInterest
      ? options.classesOfInterest.map((label) => {
          const labelStr = String(label);
          const enumValue = this.labelMap[labelStr];
          return typeof enumValue === 'number' ? enumValue : -1;
        })
      : [];

    return typeof input === 'string'
      ? await this.nativeModule.generateFromString(input, detectionThreshold, iouThreshold, classIndices, methodName)
      : await this.nativeModule.generateFromPixels(input, detectionThreshold, iouThreshold, classIndices, methodName);
  }

  /**
   * Creates an object detection instance with a user-provided model binary and label map.
   * Use this when working with a custom-exported model that is not one of the built-in presets.
   * Internally uses `'custom'` as the model name for telemetry unless overridden.
   *
   * ## Required model contract
   *
   * The `.pte` model binary must expose a single `forward` method with the following interface:
   *
   * **Input:** one `float32` tensor of shape `[1, 3, H, W]` — a single RGB image, values in
   * `[0, 1]` after optional per-channel normalization `(pixel − mean) / std`.
   * H and W are read from the model's declared input shape at load time.
   *
   * **Outputs:** exactly three `float32` tensors, in this order:
   * 1. Bounding boxes — flat `[4·N]` array of `(x1, y1, x2, y2)` coordinates in model-input
   *    pixel space, repeated for N detections.
   * 2. Confidence scores — flat `[N]` array of values in `[0, 1]`.
   * 3. Class indices — flat `[N]` array of `float32`-encoded integer class indices
   *    (0-based, matching the order of entries in your `labelMap`).
   *
   * Preprocessing (resize → normalize) and postprocessing (coordinate rescaling, threshold
   * filtering, NMS) are handled by the native runtime — your model only needs to produce
   * the raw detections above.
   *
   * The `config` object is also used as the model config, so it may optionally carry
   * `availableInputSizes`, `defaultInputSize`, `defaultDetectionThreshold` and
   * `defaultIouThreshold`. Missing thresholds fall back to 0.7 / 0.55.
   * @param modelSource - A fetchable resource pointing to the model binary.
   * @param config - A {@link ObjectDetectionConfig} object with the label map and optional preprocessing parameters.
   * @param onDownloadProgress - Optional callback to monitor download progress, receiving a value between 0 and 1.
   * @returns A Promise resolving to an `ObjectDetectionModule` instance typed to the provided label map.
   */
  static async fromCustomModel(modelSource, config, onDownloadProgress = () => {}) {
    const nativeModule = await loadNativeDetector(
      modelSource,
      config.labelMap,
      config.preprocessorConfig,
      onDownloadProgress
    );
    return new ObjectDetectionModule(config.labelMap, config, nativeModule);
  }
}
//# sourceMappingURL=ObjectDetectionModule.js.map