react-native-executorch
Version:
An easy way to run AI models in React Native with ExecuTorch
306 lines (293 loc) • 13.8 kB
JavaScript
;
import { RnExecutorchErrorCode } from '../../errors/ErrorCodes';
import { RnExecutorchError } from '../../errors/errorUtils';
import { fetchModelPath, VisionLabeledModule } from './VisionLabeledModule';
import { CocoLabel, CocoLabelYolo, IMAGENET1K_MEAN, IMAGENET1K_STD } from '../../constants/commonVision';
const YOLO_SEG_CONFIG = {
preprocessorConfig: undefined,
labelMap: CocoLabelYolo,
availableInputSizes: [384, 512, 640],
defaultInputSize: 384,
defaultConfidenceThreshold: 0.5,
defaultIouThreshold: 0.5,
postprocessorConfig: {
applyNMS: false
}
};
const RF_DETR_NANO_SEG_CONFIG = {
preprocessorConfig: {
normMean: IMAGENET1K_MEAN,
normStd: IMAGENET1K_STD
},
labelMap: CocoLabel,
availableInputSizes: undefined,
defaultInputSize: undefined,
//RFDetr exposes only one method named forward
defaultConfidenceThreshold: 0.5,
defaultIouThreshold: 0.5,
postprocessorConfig: {
applyNMS: true
}
};
/**
* Builds a reverse map from 0-based model class index to label key name, and
* computes the minimum enum value (offset) so TS enum values can be converted
* to 0-based model indices.
* @param labelMap - The label enum to build the index map from.
* @returns An object containing `indexToLabel` map and `minValue` offset.
*/
function buildClassIndexMap(labelMap) {
const entries = [];
for (const [name, value] of Object.entries(labelMap)) {
if (typeof value === 'number') entries.push([name, value]);
}
const minValue = Math.min(...entries.map(([, v]) => v));
const indexToLabel = new Map();
for (const [name, value] of entries) {
indexToLabel.set(value - minValue, name);
}
return {
indexToLabel,
minValue
};
}
const ModelConfigs = {
'yolo26n-seg': YOLO_SEG_CONFIG,
'yolo26s-seg': YOLO_SEG_CONFIG,
'yolo26m-seg': YOLO_SEG_CONFIG,
'yolo26l-seg': YOLO_SEG_CONFIG,
'yolo26x-seg': YOLO_SEG_CONFIG,
'rfdetr-nano-seg': RF_DETR_NANO_SEG_CONFIG
};
/** @internal */
/**
* Resolves the label map type for a given built-in model name.
* @typeParam M - A built-in model name from {@link InstanceSegmentationModelName}.
* @category Types
*/
/**
* Resolves the label type: if `T` is a {@link InstanceSegmentationModelName}, looks up its labels
* from the built-in config; otherwise uses `T` directly as a {@link LabelEnum}.
* @internal
*/
/**
* Generic instance segmentation module with type-safe label maps.
* Use a model name (e.g. `'yolo26n-seg'`) as the generic parameter for pre-configured models,
* or a custom label enum for custom configs.
*
* Supported models (download from HuggingFace):
* - `yolo26n-seg`, `yolo26s-seg`, `yolo26m-seg`, `yolo26l-seg`, `yolo26x-seg` - YOLO models with COCO labels (80 classes)
* - `rfdetr-nano-seg` - RF-DETR Nano model with COCO labels (80 classes)
* @typeParam T - Either a pre-configured model name from {@link InstanceSegmentationModelName}
* or a custom label map conforming to {@link LabelEnum}.
* @category Typescript API
* @example
* ```ts
* const segmentation = await InstanceSegmentationModule.fromModelName({
* modelName: 'yolo26n-seg',
* modelSource: 'https://huggingface.co/.../yolo26n-seg.pte',
* });
*
* const results = await segmentation.forward('path/to/image.jpg', {
* confidenceThreshold: 0.5,
* iouThreshold: 0.45,
* maxInstances: 20,
* inputSize: 640,
* });
* ```
*/
export class InstanceSegmentationModule extends VisionLabeledModule {
constructor(labelMap, modelConfig, nativeModule, classIndexToLabel, labelEnumOffset) {
super(labelMap, nativeModule);
this.modelConfig = modelConfig;
this.classIndexToLabel = classIndexToLabel;
this.labelEnumOffset = labelEnumOffset;
}
/**
* Creates an instance segmentation module for a pre-configured model.
* The config object is discriminated by `modelName` — each model can require different fields.
* @param config - A {@link InstanceSegmentationModelSources} object specifying which model to load and where to fetch it from.
* @param onDownloadProgress - Optional callback to monitor download progress, receiving a value between 0 and 1.
* @returns A Promise resolving to an `InstanceSegmentationModule` instance typed to the chosen model's label map.
* @example
* ```ts
* const segmentation = await InstanceSegmentationModule.fromModelName({
* modelName: 'yolo26n-seg',
* modelSource: 'https://huggingface.co/.../yolo26n-seg.pte',
* });
* ```
*/
static async fromModelName(config, onDownloadProgress = () => {}) {
const {
modelName,
modelSource
} = config;
const modelConfig = ModelConfigs[modelName];
const path = await fetchModelPath(modelSource, onDownloadProgress);
if (typeof global.loadInstanceSegmentation !== 'function') {
throw new RnExecutorchError(RnExecutorchErrorCode.ModuleNotLoaded, `global.loadInstanceSegmentation is not available`);
}
const {
indexToLabel,
minValue
} = buildClassIndexMap(modelConfig.labelMap);
const nativeModule = await global.loadInstanceSegmentation(path, modelConfig.preprocessorConfig?.normMean || [], modelConfig.preprocessorConfig?.normStd || [], modelConfig.postprocessorConfig?.applyNMS ?? true);
return new InstanceSegmentationModule(modelConfig.labelMap, modelConfig, nativeModule, indexToLabel, minValue);
}
/**
* Creates an instance segmentation module with a user-provided label map and custom config.
* Use this when working with a custom-exported segmentation model that is not one of the pre-configured models.
* @param modelSource - A fetchable resource pointing to the model binary.
* @param config - A {@link InstanceSegmentationConfig} object with the label map and optional preprocessing parameters.
* @param onDownloadProgress - Optional callback to monitor download progress, receiving a value between 0 and 1.
* @returns A Promise resolving to an `InstanceSegmentationModule` instance typed to the provided label map.
* @example
* ```ts
* const MyLabels = { PERSON: 0, CAR: 1 } as const;
* const segmentation = await InstanceSegmentationModule.fromCustomModel(
* 'https://huggingface.co/.../custom_model.pte',
* {
* labelMap: MyLabels,
* availableInputSizes: [640],
* defaultInputSize: 640,
* defaultConfidenceThreshold: 0.5,
* defaultIouThreshold: 0.45,
* postprocessorConfig: { applyNMS: true },
* },
* );
* ```
*/
static async fromCustomModel(modelSource, config, onDownloadProgress = () => {}) {
const path = await fetchModelPath(modelSource, onDownloadProgress);
if (typeof global.loadInstanceSegmentation !== 'function') {
throw new RnExecutorchError(RnExecutorchErrorCode.ModuleNotLoaded, `global.loadInstanceSegmentation is not available`);
}
const {
indexToLabel,
minValue
} = buildClassIndexMap(config.labelMap);
const nativeModule = await global.loadInstanceSegmentation(path, config.preprocessorConfig?.normMean || [], config.preprocessorConfig?.normStd || [], config.postprocessorConfig?.applyNMS ?? true);
return new InstanceSegmentationModule(config.labelMap, config, nativeModule, indexToLabel, minValue);
}
/**
* Returns the available input sizes for this model, or undefined if the model accepts any size.
* @returns An array of available input sizes, or undefined if not constrained.
* @example
* ```ts
* const sizes = segmentation.getAvailableInputSizes();
* console.log(sizes); // [384, 512, 640] for YOLO models, or undefined for RF-DETR
* ```
*/
getAvailableInputSizes() {
return this.modelConfig.availableInputSizes;
}
/**
* Override runOnFrame to add label mapping for VisionCamera integration.
* The parent's runOnFrame returns raw native results with class indices;
* this override maps them to label strings and provides an options-based API.
* @returns A worklet function for VisionCamera frame processing.
* @throws {RnExecutorchError} If the underlying native worklet is unavailable (should not occur on a loaded module).
*/
get runOnFrame() {
const baseRunOnFrame = super.runOnFrame;
if (!baseRunOnFrame) {
throw new RnExecutorchError(RnExecutorchErrorCode.ModuleNotLoaded, 'Model is not loaded. Ensure the model has been loaded before using runOnFrame.');
}
// Convert Map to plain object for worklet serialization
const labelLookup = {};
this.classIndexToLabel.forEach((label, index) => {
labelLookup[index] = label;
});
// Create reverse map (label → enum value) for classesOfInterest lookup
const labelMap = {};
for (const [name, value] of Object.entries(this.labelMap)) {
if (typeof value === 'number') {
labelMap[name] = value;
}
}
const labelEnumOffset = this.labelEnumOffset;
const defaultConfidenceThreshold = this.modelConfig.defaultConfidenceThreshold ?? 0.5;
const defaultIouThreshold = this.modelConfig.defaultIouThreshold ?? 0.5;
const defaultInputSize = this.modelConfig.defaultInputSize;
return (frame, isFrontCamera, options) => {
'worklet';
const confidenceThreshold = options?.confidenceThreshold ?? defaultConfidenceThreshold;
const iouThreshold = options?.iouThreshold ?? defaultIouThreshold;
const maxInstances = options?.maxInstances ?? 100;
const returnMaskAtOriginalResolution = options?.returnMaskAtOriginalResolution ?? true;
const inputSize = options?.inputSize ?? defaultInputSize;
const methodName = inputSize !== undefined ? `forward_${inputSize}` : 'forward';
const classIndices = options?.classesOfInterest ? options.classesOfInterest.map(label => {
const labelStr = String(label);
const enumValue = labelMap[labelStr];
// Don't normalize - send raw enum values to match model output
return typeof enumValue === 'number' ? enumValue : -1;
}) : [];
const nativeResults = baseRunOnFrame(frame, isFrontCamera, confidenceThreshold, iouThreshold, maxInstances, classIndices, returnMaskAtOriginalResolution, methodName);
return nativeResults.map(inst => ({
bbox: inst.bbox,
mask: inst.mask,
maskWidth: inst.maskWidth,
maskHeight: inst.maskHeight,
label: labelLookup[inst.classIndex - labelEnumOffset] ?? String(inst.classIndex),
score: inst.score
}));
};
}
/**
* Executes the model's forward pass to perform instance segmentation on the provided image.
*
* Supports two input types:
* 1. **String path/URI**: File path, URL, or Base64-encoded string
* 2. **PixelData**: Raw pixel data from image libraries (e.g., NitroImage)
* @param input - Image source (string path or PixelData object)
* @param options - Optional configuration for the segmentation process. Includes `confidenceThreshold`, `iouThreshold`, `maxInstances`, `classesOfInterest`, `returnMaskAtOriginalResolution`, and `inputSize`.
* @returns A Promise resolving to an array of {@link SegmentedInstance} objects with `bbox`, `mask`, `maskWidth`, `maskHeight`, `label`, `score`.
* @throws {RnExecutorchError} If the model is not loaded or if an invalid `inputSize` is provided.
* @example
* ```ts
* const results = await segmentation.forward('path/to/image.jpg', {
* confidenceThreshold: 0.6,
* iouThreshold: 0.5,
* maxInstances: 10,
* inputSize: 640,
* classesOfInterest: ['PERSON', 'CAR'],
* returnMaskAtOriginalResolution: true,
* });
*
* results.forEach((inst) => {
* console.log(`${inst.label}: ${(inst.score * 100).toFixed(1)}%`);
* });
* ```
*/
async forward(input, options) {
if (this.nativeModule == null) {
throw new RnExecutorchError(RnExecutorchErrorCode.ModuleNotLoaded, 'The model is currently not loaded.');
}
const confidenceThreshold = options?.confidenceThreshold ?? this.modelConfig.defaultConfidenceThreshold ?? 0.5;
const iouThreshold = options?.iouThreshold ?? this.modelConfig.defaultIouThreshold ?? 0.5;
const maxInstances = options?.maxInstances ?? 100;
const returnMaskAtOriginalResolution = options?.returnMaskAtOriginalResolution ?? true;
const inputSize = options?.inputSize ?? this.modelConfig.defaultInputSize;
if (this.modelConfig.availableInputSizes && inputSize !== undefined && !this.modelConfig.availableInputSizes.includes(inputSize)) {
throw new RnExecutorchError(RnExecutorchErrorCode.InvalidArgument, `Invalid inputSize: ${inputSize}. Available sizes: ${this.modelConfig.availableInputSizes.join(', ')}`);
}
const methodName = inputSize !== undefined ? `forward_${inputSize}` : 'forward';
const classIndices = options?.classesOfInterest ? options.classesOfInterest.map(label => {
const labelStr = String(label);
const enumValue = this.labelMap[labelStr];
// Don't normalize - send raw enum values to match model output
return typeof enumValue === 'number' ? enumValue : -1;
}) : [];
const nativeResult = typeof input === 'string' ? await this.nativeModule.generateFromString(input, confidenceThreshold, iouThreshold, maxInstances, classIndices, returnMaskAtOriginalResolution, methodName) : await this.nativeModule.generateFromPixels(input, confidenceThreshold, iouThreshold, maxInstances, classIndices, returnMaskAtOriginalResolution, methodName);
return nativeResult.map(inst => ({
bbox: inst.bbox,
mask: inst.mask,
maskWidth: inst.maskWidth,
maskHeight: inst.maskHeight,
label: this.classIndexToLabel.get(inst.classIndex - this.labelEnumOffset) ?? String(inst.classIndex),
score: inst.score
}));
}
}
//# sourceMappingURL=InstanceSegmentationModule.js.map