react-native-executorch
Version:
An easy way to run AI models in React Native with ExecuTorch
181 lines (167 loc) • 6.64 kB
text/typescript
import { RnExecutorchError } from '../errors/errorUtils';
import { LabelEnum, Triple, ResourceSource, PixelData, Frame } from './common';
/**
 * Describes a custom (user-supplied) semantic segmentation model.
 * @typeParam T - The {@link LabelEnum} type describing the model's classes.
 * @property labelMap - Enum-like object that maps each class name to its index.
 * @property preprocessorConfig - Input preprocessing settings; the whole object may be omitted.
 * @property preprocessorConfig.normMean - Optional per-channel mean used to normalize the input.
 * @property preprocessorConfig.normStd - Optional per-channel standard deviation used to normalize the input.
 * @category Types
 */
export type SemanticSegmentationConfig<T extends LabelEnum> = {
  labelMap: T;
  preprocessorConfig?: {
    normMean?: Triple<number>;
    normStd?: Triple<number>;
  };
};
/**
 * Discriminated union of the per-model configs accepted by
 * {@link SemanticSegmentationModule.fromModelName}.
 *
 * Every member is tagged by its `modelName` literal and lists the fields that
 * model requires; at present each one only needs a `modelSource`.
 * When a model needs additional sources or options, extend (or add) its
 * member here rather than widening the shared shape.
 * @category Types
 */
export type SemanticSegmentationModelSources =
  | { modelName: 'deeplab-v3-resnet50'; modelSource: ResourceSource }
  | { modelName: 'deeplab-v3-resnet101'; modelSource: ResourceSource }
  | { modelName: 'deeplab-v3-mobilenet-v3-large'; modelSource: ResourceSource }
  | { modelName: 'lraspp-mobilenet-v3-large'; modelSource: ResourceSource }
  | { modelName: 'fcn-resnet50'; modelSource: ResourceSource }
  | { modelName: 'fcn-resnet101'; modelSource: ResourceSource }
  | { modelName: 'deeplab-v3-resnet50-quantized'; modelSource: ResourceSource }
  | { modelName: 'deeplab-v3-resnet101-quantized'; modelSource: ResourceSource }
  | {
      modelName: 'deeplab-v3-mobilenet-v3-large-quantized';
      modelSource: ResourceSource;
    }
  | {
      modelName: 'lraspp-mobilenet-v3-large-quantized';
      modelSource: ResourceSource;
    }
  | { modelName: 'fcn-resnet50-quantized'; modelSource: ResourceSource }
  | { modelName: 'fcn-resnet101-quantized'; modelSource: ResourceSource }
  | { modelName: 'selfie-segmentation'; modelSource: ResourceSource };
/**
 * Every built-in semantic segmentation model name, as a union of string
 * literals (for example `'deeplab-v3-resnet50'` or `'selfie-segmentation'`).
 *
 * Derived from the `modelName` tag of {@link SemanticSegmentationModelSources},
 * so it stays in sync with that union automatically.
 * @category Types
 */
export type SemanticSegmentationModelName =
  SemanticSegmentationModelSources['modelName'];
/**
 * Resolves to the `modelName` literal type carried by a
 * {@link SemanticSegmentationModelSources} config type.
 * @typeParam C - The config type whose model name should be extracted.
 * @category Types
 */
export type ModelNameOf<C extends SemanticSegmentationModelSources> =
  C['modelName'];
/**
 * Class labels used in the DeepLab semantic segmentation model.
 *
 * The numeric values are spelled out explicitly; they are the same values the
 * members would receive from implicit auto-numbering starting at 0.
 * NOTE(review): each value presumably matches the class index emitted by the
 * model — confirm against the model export before reordering or renumbering.
 * @category Types
 */
export enum DeeplabLabel {
  BACKGROUND = 0,
  AEROPLANE = 1,
  BICYCLE = 2,
  BIRD = 3,
  BOAT = 4,
  BOTTLE = 5,
  BUS = 6,
  CAR = 7,
  CAT = 8,
  CHAIR = 9,
  COW = 10,
  DININGTABLE = 11,
  DOG = 12,
  HORSE = 13,
  MOTORBIKE = 14,
  PERSON = 15,
  POTTEDPLANT = 16,
  SHEEP = 17,
  SOFA = 18,
  TRAIN = 19,
  TVMONITOR = 20,
}
/**
 * Class labels used in the selfie semantic segmentation model.
 *
 * The numeric values are spelled out explicitly; they are the same values the
 * members would receive from implicit auto-numbering starting at 0.
 * @category Types
 */
export enum SelfieSegmentationLabel {
  SELFIE = 0,
  BACKGROUND = 1,
}
/**
 * Props accepted by the `useSemanticSegmentation` hook.
 * @typeParam C - A {@link SemanticSegmentationModelSources} config that selects which built-in model to load.
 * @category Types
 */
export interface SemanticSegmentationProps<
  C extends SemanticSegmentationModelSources,
> {
  /**
   * The model config, containing `modelName` and `modelSource`.
   */
  model: C;
  /**
   * Optional flag that can prevent the hook from automatically loading the
   * model (and from downloading its data on first use) once the hook runs.
   */
  preventLoad?: boolean;
}
/**
 * Value returned by the `useSemanticSegmentation` hook.
 * Exposes the loading/inference state of a semantic segmentation model along
 * with the functions used to run it.
 * @typeParam L - The {@link LabelEnum} describing the model's class labels.
 * @category Types
 */
export interface SemanticSegmentationType<L extends LabelEnum> {
  /**
   * The error raised while loading, downloading, or running the model, or
   * `null` when there is none.
   */
  error: RnExecutorchError | null;

  /**
   * `true` once the segmentation model is loaded and can process images.
   */
  isReady: boolean;

  /**
   * `true` while the model is busy processing an image.
   */
  isGenerating: boolean;

  /**
   * Download progress of the model binary, expressed as a value between 0 and 1.
   */
  downloadProgress: number;

  /**
   * Runs the model's forward pass to segment the given image.
   *
   * Two kinds of input are accepted:
   * 1. **String path/URI**: a file path, URL, or Base64-encoded string
   * 2. **PixelData**: raw pixel data from image libraries (e.g., NitroImage)
   *
   * **Note**: For VisionCamera frame processing, use `runOnFrame` instead.
   * @typeParam K - The label keys of `L` requested through `classesOfInterest`.
   * @param input - Image source (string or PixelData object)
   * @param classesOfInterest - Optional array of label keys selecting which per-class probability masks to include in the output. `ARGMAX` is always returned regardless.
   * @param resizeToInput - Whether the output masks are resized to the original input image dimensions. If `false`, the raw model output dimensions are returned. Defaults to `true`.
   * @returns A Promise resolving to an object holding an `'ARGMAX'` `Int32Array` of per-pixel class indices, plus a `Float32Array` of per-pixel probabilities for each requested class label.
   * @throws {RnExecutorchError} If the model is not loaded or is currently processing another image.
   */
  forward: <K extends keyof L>(
    input: string | PixelData,
    classesOfInterest?: K[],
    resizeToInput?: boolean
  ) => Promise<Record<'ARGMAX', Int32Array> & Record<K, Float32Array>>;

  /**
   * Synchronous worklet function for real-time VisionCamera frame processing.
   * Native buffer extraction and cleanup are handled automatically.
   *
   * **Use this for VisionCamera frame processing in worklets.**
   * For async processing, use `forward()` instead.
   *
   * Available after the model is loaded (`isReady: true`); the property is
   * typed as nullable, presumably because it is `null` until then.
   * Unlike `forward`, the class labels here are plain strings rather than
   * keys of `L`.
   * @param frame - VisionCamera Frame object
   * @param isFrontCamera - Whether the front camera is active; used for mirroring corrections.
   * @param classesOfInterest - Labels for which per-class probability masks should be returned.
   * @param resizeToInput - Whether masks are resized to the original frame dimensions. Defaults to `true`.
   * @returns Object with an `ARGMAX` Int32Array and per-class Float32Array masks.
   */
  runOnFrame:
    | ((
        frame: Frame,
        isFrontCamera: boolean,
        classesOfInterest?: string[],
        resizeToInput?: boolean
      ) => Record<'ARGMAX', Int32Array> & Record<string, Float32Array>)
    | null;
}