react-native-executorch
Version:
An easy way to run AI models in React Native with ExecuTorch
468 lines (432 loc) • 17 kB
text/typescript
import {
ResourceSource,
LabelEnum,
PixelData,
Frame,
} from '../../types/common';
import {
InstanceSegmentationModelSources,
InstanceSegmentationConfig,
InstanceSegmentationModelName,
InstanceModelNameOf,
NativeSegmentedInstance,
SegmentedInstance,
InstanceSegmentationOptions,
} from '../../types/instanceSegmentation';
import { RnExecutorchErrorCode } from '../../errors/ErrorCodes';
import { RnExecutorchError } from '../../errors/errorUtils';
import {
fetchModelPath,
ResolveLabels as ResolveLabelsFor,
VisionLabeledModule,
} from './VisionLabeledModule';
import {
CocoLabel,
CocoLabelYolo,
IMAGENET1K_MEAN,
IMAGENET1K_STD,
} from '../../constants/commonVision';
// Shared configuration for every YOLO-seg variant: COCO labels in YOLO
// ordering, a fixed set of square input sizes, and NMS disabled on our side.
const YOLO_SEG_CONFIG = {
  // No normalization constants are passed to the native loader — presumably
  // the exported YOLO graph normalizes internally; TODO confirm with export.
  preprocessorConfig: undefined,
  labelMap: CocoLabelYolo,
  // Each size corresponds to a dedicated `forward_<size>` method natively.
  availableInputSizes: [384, 512, 640] as const,
  defaultInputSize: 384,
  defaultConfidenceThreshold: 0.5,
  defaultIouThreshold: 0.5,
  postprocessorConfig: {
    // NMS is skipped for YOLO models — presumably already applied inside the
    // exported graph; verify against the model export pipeline.
    applyNMS: false,
  },
} satisfies InstanceSegmentationConfig<typeof CocoLabelYolo>;
// Configuration for the RF-DETR Nano segmentation model: standard COCO labels,
// ImageNet normalization, single input size, NMS applied on the native side.
const RF_DETR_NANO_SEG_CONFIG = {
  preprocessorConfig: { normMean: IMAGENET1K_MEAN, normStd: IMAGENET1K_STD },
  labelMap: CocoLabel,
  // No size constraint and no default size — with inputSize undefined the
  // module calls the plain `forward` method (see forward()/runOnFrame).
  availableInputSizes: undefined,
  defaultInputSize: undefined, //RFDetr exposes only one method named forward
  defaultConfidenceThreshold: 0.5,
  defaultIouThreshold: 0.5,
  postprocessorConfig: {
    applyNMS: true,
  },
} satisfies InstanceSegmentationConfig<typeof CocoLabel>;
/**
 * Builds a reverse map from 0-based model class index to label key name, and
 * computes the minimum enum value (offset) so TS enum values can be converted
 * to 0-based model indices.
 *
 * Only numeric entries are considered; this also skips the reverse
 * (number → name) entries that TypeScript numeric enums generate.
 * @param labelMap - The label enum to build the index map from.
 * @returns An object containing `indexToLabel` map and `minValue` offset.
 */
function buildClassIndexMap(labelMap: LabelEnum): {
  indexToLabel: Map<number, string>;
  minValue: number;
} {
  const entries: [string, number][] = [];
  for (const [name, value] of Object.entries(labelMap)) {
    if (typeof value === 'number') entries.push([name, value]);
  }
  // Guard: Math.min() with no arguments returns Infinity, which would make
  // every offset computation below (and all later classIndex lookups) NaN-like
  // garbage. An empty or purely string-valued label map yields an empty map
  // with a neutral offset of 0 instead.
  if (entries.length === 0) {
    return { indexToLabel: new Map<number, string>(), minValue: 0 };
  }
  const minValue = Math.min(...entries.map(([, v]) => v));
  const indexToLabel = new Map<number, string>();
  for (const [name, value] of entries) {
    indexToLabel.set(value - minValue, name);
  }
  return { indexToLabel, minValue };
}
// Maps every built-in model name to its configuration. All YOLO variants share
// one config object; RF-DETR Nano has its own (different label ordering and
// explicit ImageNet normalization). The `satisfies` clause ensures every
// InstanceSegmentationModelName has an entry without widening the value types.
const ModelConfigs = {
  'yolo26n-seg': YOLO_SEG_CONFIG,
  'yolo26s-seg': YOLO_SEG_CONFIG,
  'yolo26m-seg': YOLO_SEG_CONFIG,
  'yolo26l-seg': YOLO_SEG_CONFIG,
  'yolo26x-seg': YOLO_SEG_CONFIG,
  'rfdetr-nano-seg': RF_DETR_NANO_SEG_CONFIG,
} as const satisfies Record<
  InstanceSegmentationModelName,
  | InstanceSegmentationConfig<typeof CocoLabel>
  | InstanceSegmentationConfig<typeof CocoLabelYolo>
>;
/**
 * Type of the built-in model-name → config map; used as the lookup table for
 * label resolution below.
 * @internal
 */
type ModelConfigsType = typeof ModelConfigs;
/**
 * Resolves the label map type for a given built-in model name.
 * @typeParam M - A built-in model name from {@link InstanceSegmentationModelName}.
 * @category Types
 */
export type InstanceSegmentationLabels<
  M extends InstanceSegmentationModelName,
> = ResolveLabels<M>;
/**
 * Resolves the label type: if `T` is a {@link InstanceSegmentationModelName}, looks up its labels
 * from the built-in config; otherwise uses `T` directly as a {@link LabelEnum}.
 * @internal
 */
type ResolveLabels<T extends InstanceSegmentationModelName | LabelEnum> =
  ResolveLabelsFor<T, ModelConfigsType>;
/**
* Generic instance segmentation module with type-safe label maps.
* Use a model name (e.g. `'yolo26n-seg'`) as the generic parameter for pre-configured models,
* or a custom label enum for custom configs.
*
* Supported models (download from HuggingFace):
* - `yolo26n-seg`, `yolo26s-seg`, `yolo26m-seg`, `yolo26l-seg`, `yolo26x-seg` - YOLO models with COCO labels (80 classes)
* - `rfdetr-nano-seg` - RF-DETR Nano model with COCO labels (80 classes)
* @typeParam T - Either a pre-configured model name from {@link InstanceSegmentationModelName}
* or a custom label map conforming to {@link LabelEnum}.
* @category Typescript API
* @example
* ```ts
* const segmentation = await InstanceSegmentationModule.fromModelName({
* modelName: 'yolo26n-seg',
* modelSource: 'https://huggingface.co/.../yolo26n-seg.pte',
* });
*
* const results = await segmentation.forward('path/to/image.jpg', {
* confidenceThreshold: 0.5,
* iouThreshold: 0.45,
* maxInstances: 20,
* inputSize: 640,
* });
* ```
*/
export class InstanceSegmentationModule<
  T extends InstanceSegmentationModelName | LabelEnum,
> extends VisionLabeledModule<
  SegmentedInstance<ResolveLabels<T>>[],
  ResolveLabels<T>
> {
  // Resolved configuration: thresholds, input sizes, pre/post-processing.
  private modelConfig: InstanceSegmentationConfig<LabelEnum>;
  // 0-based model class index → label key name (built by buildClassIndexMap).
  private classIndexToLabel: Map<number, string>;
  // Minimum numeric value of the label enum; native class indices are shifted
  // by this offset before lookup in classIndexToLabel.
  private labelEnumOffset: number;
  /**
   * Private — construct instances via {@link fromModelName} or
   * {@link fromCustomModel}, which perform the native load.
   * @param labelMap - Label enum the result labels are typed against.
   * @param modelConfig - Resolved model configuration.
   * @param nativeModule - Handle returned by `global.loadInstanceSegmentation`.
   * @param classIndexToLabel - Reverse map from 0-based class index to label name.
   * @param labelEnumOffset - Minimum enum value used to normalize native class indices.
   */
  private constructor(
    labelMap: ResolveLabels<T>,
    modelConfig: InstanceSegmentationConfig<LabelEnum>,
    nativeModule: unknown,
    classIndexToLabel: Map<number, string>,
    labelEnumOffset: number
  ) {
    super(labelMap, nativeModule);
    this.modelConfig = modelConfig;
    this.classIndexToLabel = classIndexToLabel;
    this.labelEnumOffset = labelEnumOffset;
  }
  /**
   * Creates an instance segmentation module for a pre-configured model.
   * The config object is discriminated by `modelName` — each model can require different fields.
   * @param config - A {@link InstanceSegmentationModelSources} object specifying which model to load and where to fetch it from.
   * @param onDownloadProgress - Optional callback to monitor download progress, receiving a value between 0 and 1.
   * @returns A Promise resolving to an `InstanceSegmentationModule` instance typed to the chosen model's label map.
   * @throws {RnExecutorchError} If the native `loadInstanceSegmentation` binding is not installed.
   * @example
   * ```ts
   * const segmentation = await InstanceSegmentationModule.fromModelName({
   *   modelName: 'yolo26n-seg',
   *   modelSource: 'https://huggingface.co/.../yolo26n-seg.pte',
   * });
   * ```
   */
  static async fromModelName<C extends InstanceSegmentationModelSources>(
    config: C,
    onDownloadProgress: (progress: number) => void = () => {}
  ): Promise<InstanceSegmentationModule<InstanceModelNameOf<C>>> {
    const { modelName, modelSource } = config;
    const modelConfig = ModelConfigs[modelName as keyof typeof ModelConfigs];
    const path = await fetchModelPath(modelSource, onDownloadProgress);
    // Fail fast with a descriptive error when the native JSI binding was
    // never installed (e.g. native module not linked).
    if (typeof global.loadInstanceSegmentation !== 'function') {
      throw new RnExecutorchError(
        RnExecutorchErrorCode.ModuleNotLoaded,
        `global.loadInstanceSegmentation is not available`
      );
    }
    const { indexToLabel, minValue } = buildClassIndexMap(modelConfig.labelMap);
    // Empty norm-mean/std arrays are passed when there is no preprocessor
    // config — presumably the native side treats [] as "no normalization";
    // TODO confirm against the native implementation.
    const nativeModule = await global.loadInstanceSegmentation(
      path,
      modelConfig.preprocessorConfig?.normMean || [],
      modelConfig.preprocessorConfig?.normStd || [],
      modelConfig.postprocessorConfig?.applyNMS ?? true
    );
    return new InstanceSegmentationModule<InstanceModelNameOf<C>>(
      modelConfig.labelMap as ResolveLabels<InstanceModelNameOf<C>>,
      modelConfig,
      nativeModule,
      indexToLabel,
      minValue
    );
  }
  /**
   * Creates an instance segmentation module with a user-provided label map and custom config.
   * Use this when working with a custom-exported segmentation model that is not one of the pre-configured models.
   * @param modelSource - A fetchable resource pointing to the model binary.
   * @param config - A {@link InstanceSegmentationConfig} object with the label map and optional preprocessing parameters.
   * @param onDownloadProgress - Optional callback to monitor download progress, receiving a value between 0 and 1.
   * @returns A Promise resolving to an `InstanceSegmentationModule` instance typed to the provided label map.
   * @throws {RnExecutorchError} If the native `loadInstanceSegmentation` binding is not installed.
   * @example
   * ```ts
   * const MyLabels = { PERSON: 0, CAR: 1 } as const;
   * const segmentation = await InstanceSegmentationModule.fromCustomModel(
   *   'https://huggingface.co/.../custom_model.pte',
   *   {
   *     labelMap: MyLabels,
   *     availableInputSizes: [640],
   *     defaultInputSize: 640,
   *     defaultConfidenceThreshold: 0.5,
   *     defaultIouThreshold: 0.45,
   *     postprocessorConfig: { applyNMS: true },
   *   },
   * );
   * ```
   */
  static async fromCustomModel<L extends LabelEnum>(
    modelSource: ResourceSource,
    config: InstanceSegmentationConfig<L>,
    onDownloadProgress: (progress: number) => void = () => {}
  ): Promise<InstanceSegmentationModule<L>> {
    const path = await fetchModelPath(modelSource, onDownloadProgress);
    // Same native-binding guard as fromModelName.
    if (typeof global.loadInstanceSegmentation !== 'function') {
      throw new RnExecutorchError(
        RnExecutorchErrorCode.ModuleNotLoaded,
        `global.loadInstanceSegmentation is not available`
      );
    }
    const { indexToLabel, minValue } = buildClassIndexMap(config.labelMap);
    const nativeModule = await global.loadInstanceSegmentation(
      path,
      config.preprocessorConfig?.normMean || [],
      config.preprocessorConfig?.normStd || [],
      config.postprocessorConfig?.applyNMS ?? true
    );
    return new InstanceSegmentationModule<L>(
      config.labelMap as ResolveLabels<L>,
      config,
      nativeModule,
      indexToLabel,
      minValue
    );
  }
  /**
   * Returns the available input sizes for this model, or undefined if the model accepts any size.
   * @returns An array of available input sizes, or undefined if not constrained.
   * @example
   * ```ts
   * const sizes = segmentation.getAvailableInputSizes();
   * console.log(sizes); // [384, 512, 640] for YOLO models, or undefined for RF-DETR
   * ```
   */
  getAvailableInputSizes(): readonly number[] | undefined {
    return this.modelConfig.availableInputSizes;
  }
  /**
   * Override runOnFrame to add label mapping for VisionCamera integration.
   * The parent's runOnFrame returns raw native results with class indices;
   * this override maps them to label strings and provides an options-based API.
   * @returns A worklet function for VisionCamera frame processing.
   * @throws {RnExecutorchError} If the underlying native worklet is unavailable (should not occur on a loaded module).
   */
  override get runOnFrame(): (
    frame: Frame,
    isFrontCamera: boolean,
    options?: InstanceSegmentationOptions<ResolveLabels<T>>
  ) => SegmentedInstance<ResolveLabels<T>>[] {
    const baseRunOnFrame = super.runOnFrame;
    if (!baseRunOnFrame) {
      throw new RnExecutorchError(
        RnExecutorchErrorCode.ModuleNotLoaded,
        'Model is not loaded. Ensure the model has been loaded before using runOnFrame.'
      );
    }
    // Convert Map to plain object for worklet serialization
    const labelLookup: Record<number, string> = {};
    this.classIndexToLabel.forEach((label, index) => {
      labelLookup[index] = label;
    });
    // Create reverse map (label → enum value) for classesOfInterest lookup
    const labelMap: Record<string, number> = {};
    for (const [name, value] of Object.entries(this.labelMap)) {
      if (typeof value === 'number') {
        labelMap[name] = value;
      }
    }
    // Capture defaults as plain locals so the returned worklet closes over
    // serializable values rather than `this`.
    const labelEnumOffset = this.labelEnumOffset;
    const defaultConfidenceThreshold =
      this.modelConfig.defaultConfidenceThreshold ?? 0.5;
    const defaultIouThreshold = this.modelConfig.defaultIouThreshold ?? 0.5;
    const defaultInputSize = this.modelConfig.defaultInputSize;
    return (
      frame: Frame,
      isFrontCamera: boolean,
      options?: InstanceSegmentationOptions<ResolveLabels<T>>
    ): SegmentedInstance<ResolveLabels<T>>[] => {
      'worklet';
      const confidenceThreshold =
        options?.confidenceThreshold ?? defaultConfidenceThreshold;
      const iouThreshold = options?.iouThreshold ?? defaultIouThreshold;
      const maxInstances = options?.maxInstances ?? 100;
      const returnMaskAtOriginalResolution =
        options?.returnMaskAtOriginalResolution ?? true;
      const inputSize = options?.inputSize ?? defaultInputSize;
      // NOTE(review): unlike forward(), inputSize is NOT validated against
      // availableInputSizes here — an unsupported size produces a
      // `forward_<size>` method name; confirm how the native side handles
      // unknown method names.
      const methodName =
        inputSize !== undefined ? `forward_${inputSize}` : 'forward';
      // Unknown labels map to -1 — presumably matching no class natively;
      // verify against the native filtering logic.
      const classIndices = options?.classesOfInterest
        ? options.classesOfInterest.map((label) => {
            const labelStr = String(label);
            const enumValue = labelMap[labelStr];
            // Don't normalize - send raw enum values to match model output
            return typeof enumValue === 'number' ? enumValue : -1;
          })
        : [];
      const nativeResults = baseRunOnFrame(
        frame,
        isFrontCamera,
        confidenceThreshold,
        iouThreshold,
        maxInstances,
        classIndices,
        returnMaskAtOriginalResolution,
        methodName
      );
      // Map raw class indices back to label names; fall back to the numeric
      // index as a string when the index is not in the lookup table.
      return nativeResults.map((inst: any) => ({
        bbox: inst.bbox,
        mask: inst.mask,
        maskWidth: inst.maskWidth,
        maskHeight: inst.maskHeight,
        label: (labelLookup[inst.classIndex - labelEnumOffset] ??
          String(inst.classIndex)) as keyof ResolveLabels<T>,
        score: inst.score,
      }));
    };
  }
  /**
   * Executes the model's forward pass to perform instance segmentation on the provided image.
   *
   * Supports two input types:
   * 1. **String path/URI**: File path, URL, or Base64-encoded string
   * 2. **PixelData**: Raw pixel data from image libraries (e.g., NitroImage)
   * @param input - Image source (string path or PixelData object)
   * @param options - Optional configuration for the segmentation process. Includes `confidenceThreshold`, `iouThreshold`, `maxInstances`, `classesOfInterest`, `returnMaskAtOriginalResolution`, and `inputSize`.
   * @returns A Promise resolving to an array of {@link SegmentedInstance} objects with `bbox`, `mask`, `maskWidth`, `maskHeight`, `label`, `score`.
   * @throws {RnExecutorchError} If the model is not loaded or if an invalid `inputSize` is provided.
   * @example
   * ```ts
   * const results = await segmentation.forward('path/to/image.jpg', {
   *   confidenceThreshold: 0.6,
   *   iouThreshold: 0.5,
   *   maxInstances: 10,
   *   inputSize: 640,
   *   classesOfInterest: ['PERSON', 'CAR'],
   *   returnMaskAtOriginalResolution: true,
   * });
   *
   * results.forEach((inst) => {
   *   console.log(`${inst.label}: ${(inst.score * 100).toFixed(1)}%`);
   * });
   * ```
   */
  async forward(
    input: string | PixelData,
    options?: InstanceSegmentationOptions<ResolveLabels<T>>
  ): Promise<SegmentedInstance<ResolveLabels<T>>[]> {
    if (this.nativeModule == null) {
      throw new RnExecutorchError(
        RnExecutorchErrorCode.ModuleNotLoaded,
        'The model is currently not loaded.'
      );
    }
    // Per-call options take precedence over the model config defaults.
    const confidenceThreshold =
      options?.confidenceThreshold ??
      this.modelConfig.defaultConfidenceThreshold ??
      0.5;
    const iouThreshold =
      options?.iouThreshold ?? this.modelConfig.defaultIouThreshold ?? 0.5;
    const maxInstances = options?.maxInstances ?? 100;
    const returnMaskAtOriginalResolution =
      options?.returnMaskAtOriginalResolution ?? true;
    const inputSize = options?.inputSize ?? this.modelConfig.defaultInputSize;
    // Reject sizes the model does not export a forward method for (only when
    // the config constrains sizes at all).
    if (
      this.modelConfig.availableInputSizes &&
      inputSize !== undefined &&
      !this.modelConfig.availableInputSizes.includes(
        inputSize as (typeof this.modelConfig.availableInputSizes)[number]
      )
    ) {
      throw new RnExecutorchError(
        RnExecutorchErrorCode.InvalidArgument,
        `Invalid inputSize: ${inputSize}. Available sizes: ${this.modelConfig.availableInputSizes.join(', ')}`
      );
    }
    const methodName =
      inputSize !== undefined ? `forward_${inputSize}` : 'forward';
    // Unknown labels map to -1 — presumably matching no class natively;
    // verify against the native filtering logic.
    const classIndices = options?.classesOfInterest
      ? options.classesOfInterest.map((label) => {
          const labelStr = String(label);
          const enumValue = this.labelMap[labelStr as keyof ResolveLabels<T>];
          // Don't normalize - send raw enum values to match model output
          return typeof enumValue === 'number' ? enumValue : -1;
        })
      : [];
    // Both native entry points take the same trailing argument list; only the
    // input representation differs.
    const nativeResult: NativeSegmentedInstance[] =
      typeof input === 'string'
        ? await this.nativeModule.generateFromString(
            input,
            confidenceThreshold,
            iouThreshold,
            maxInstances,
            classIndices,
            returnMaskAtOriginalResolution,
            methodName
          )
        : await this.nativeModule.generateFromPixels(
            input,
            confidenceThreshold,
            iouThreshold,
            maxInstances,
            classIndices,
            returnMaskAtOriginalResolution,
            methodName
          );
    // Map raw class indices back to label names; fall back to the numeric
    // index as a string when the index is not in the lookup table.
    return nativeResult.map((inst) => ({
      bbox: inst.bbox,
      mask: inst.mask,
      maskWidth: inst.maskWidth,
      maskHeight: inst.maskHeight,
      label: (this.classIndexToLabel.get(
        inst.classIndex - this.labelEnumOffset
      ) ?? String(inst.classIndex)) as keyof ResolveLabels<T>,
      score: inst.score,
    }));
  }
}