react-native-executorch
Version:
An easy way to run AI models in React Native with ExecuTorch
200 lines (190 loc) • 8.96 kB
text/typescript
import { ResourceSource, LabelEnum, PixelData } from '../../types/common';
import {
DeeplabLabel,
ModelNameOf,
SemanticSegmentationModelSources,
SemanticSegmentationConfig,
SemanticSegmentationModelName,
SelfieSegmentationLabel,
} from '../../types/semanticSegmentation';
import { IMAGENET1K_MEAN, IMAGENET1K_STD } from '../../constants/commonVision';
import {
fetchModelPath,
ResolveLabels as ResolveLabelsFor,
VisionLabeledModule,
} from './VisionLabeledModule';
/**
 * Configuration shared by every built-in model that uses the `DeeplabLabel`
 * label map, with `IMAGENET1K_MEAN` / `IMAGENET1K_STD` as the normalization
 * mean and standard deviation.
 */
const PascalVocSegmentationConfig = {
  labelMap: DeeplabLabel,
  preprocessorConfig: { normMean: IMAGENET1K_MEAN, normStd: IMAGENET1K_STD },
};
/**
 * Registry of built-in semantic segmentation models, keyed by model name.
 * Each entry provides the label map and the optional preprocessing
 * (normalization) settings that `fromModelName` uses when loading the model.
 */
const ModelConfigs = {
// Base models: all share `PascalVocSegmentationConfig`.
'deeplab-v3-resnet50': PascalVocSegmentationConfig,
'deeplab-v3-resnet101': PascalVocSegmentationConfig,
'deeplab-v3-mobilenet-v3-large': PascalVocSegmentationConfig,
'lraspp-mobilenet-v3-large': PascalVocSegmentationConfig,
'fcn-resnet50': PascalVocSegmentationConfig,
'fcn-resnet101': PascalVocSegmentationConfig,
// Quantized variants reuse the same configuration as their base models.
'deeplab-v3-resnet50-quantized': PascalVocSegmentationConfig,
'deeplab-v3-resnet101-quantized': PascalVocSegmentationConfig,
'deeplab-v3-mobilenet-v3-large-quantized': PascalVocSegmentationConfig,
'lraspp-mobilenet-v3-large-quantized': PascalVocSegmentationConfig,
'fcn-resnet50-quantized': PascalVocSegmentationConfig,
'fcn-resnet101-quantized': PascalVocSegmentationConfig,
// Selfie segmentation has its own label map and no normalization settings;
// `fromModelName` then passes empty mean/std arrays to the native loader.
'selfie-segmentation': {
labelMap: SelfieSegmentationLabel,
preprocessorConfig: undefined,
},
// `as const` preserves the literal keys and per-entry types for type-level
// lookups; `satisfies` checks that every `SemanticSegmentationModelName` has a
// valid entry without widening the inferred type.
} as const satisfies Record<
SemanticSegmentationModelName,
SemanticSegmentationConfig<LabelEnum>
>;
/**
 * Static type of the `ModelConfigs` registry, used for type-level lookups
 * by model name.
 * @internal
 */
type ModelConfigsType = typeof ModelConfigs;
/**
 * Resolves the {@link LabelEnum} for a given built-in model name.
 * This is the type of the `labelMap` property of that model's entry in the
 * built-in model configuration table.
 * @typeParam M - A built-in model name from {@link SemanticSegmentationModelName}.
 * @category Types
 */
export type SegmentationLabels<M extends SemanticSegmentationModelName> =
ModelConfigsType[M]['labelMap'];
/**
 * Label map type for a module's generic parameter, obtained by applying the
 * `ResolveLabels` helper from `./VisionLabeledModule` to the built-in model
 * configuration table.
 * NOTE(review): presumably a model name resolves to that model's `labelMap`
 * type and a custom {@link LabelEnum} resolves to itself — confirm against
 * the helper's definition.
 * @internal
 */
type ResolveLabels<T extends SemanticSegmentationModelName | LabelEnum> =
ResolveLabelsFor<T, ModelConfigsType>;
/**
 * Generic semantic segmentation module with type-safe label maps.
 * Use a model name (e.g. `'deeplab-v3-resnet50'`) as the generic parameter for built-in models,
 * or a custom label enum for custom configs.
 * @typeParam T - Either a built-in model name (`'deeplab-v3-resnet50'`,
 * `'deeplab-v3-resnet50-quantized'`, `'deeplab-v3-resnet101'`,
 * `'deeplab-v3-resnet101-quantized'`, `'deeplab-v3-mobilenet-v3-large'`,
 * `'deeplab-v3-mobilenet-v3-large-quantized'`, `'lraspp-mobilenet-v3-large'`,
 * `'lraspp-mobilenet-v3-large-quantized'`, `'fcn-resnet50'`,
 * `'fcn-resnet50-quantized'`, `'fcn-resnet101'`, `'fcn-resnet101-quantized'`,
 * `'selfie-segmentation'`) or a custom {@link LabelEnum} label map.
 * @category Typescript API
 */
export class SemanticSegmentationModule<
  T extends SemanticSegmentationModelName | LabelEnum,
> extends VisionLabeledModule<
  Record<'ARGMAX', Int32Array> & Record<keyof ResolveLabels<T>, Float32Array>,
  ResolveLabels<T>
> {
  /**
   * Instances are created through {@link SemanticSegmentationModule.fromModelName}
   * or {@link SemanticSegmentationModule.fromCustomModel}, which load the native
   * module before constructing the wrapper.
   */
  private constructor(labelMap: ResolveLabels<T>, nativeModule: unknown) {
    super(labelMap, nativeModule);
  }

  /**
   * Shared loading pipeline for both factory methods: derives the preprocessing
   * parameters and class names from `config`, fetches the model binary, and
   * asks the native runtime to load it.
   * @param modelSource - A fetchable resource pointing to the model binary.
   * @param config - Label map plus optional preprocessing parameters.
   * @param onDownloadProgress - Callback forwarded to the model fetch.
   * @returns A Promise resolving to the native module handle.
   */
  private static async loadNativeModule(
    modelSource: ResourceSource,
    config: SemanticSegmentationConfig<LabelEnum>,
    onDownloadProgress: (progress: number) => void
  ): Promise<unknown> {
    // Missing normalization settings are passed to the native loader as empty arrays.
    const normMean = config.preprocessorConfig?.normMean ?? [];
    const normStd = config.preprocessorConfig?.normStd ?? [];
    // Numeric TypeScript enums also contain reverse (value -> name) keys;
    // keep only the keys that do not parse as numbers, i.e. the class names.
    const allClassNames = Object.keys(config.labelMap).filter((key) =>
      Number.isNaN(Number(key))
    );
    const modelPath = await fetchModelPath(modelSource, onDownloadProgress);
    return global.loadSemanticSegmentation(
      modelPath,
      normMean,
      normStd,
      allClassNames
    );
  }

  /**
   * Creates a segmentation instance for a built-in model.
   * The config object is discriminated by `modelName` — each model can require different fields.
   * @param namedSources - A {@link SemanticSegmentationModelSources} object specifying which model to load and where to fetch it from.
   * @param onDownloadProgress - Optional callback to monitor download progress, receiving a value between 0 and 1.
   * @returns A Promise resolving to a `SemanticSegmentationModule` instance typed to the chosen model's label map.
   * @example
   * ```ts
   * const segmentation = await SemanticSegmentationModule.fromModelName(DEEPLAB_V3_RESNET50);
   * ```
   */
  static async fromModelName<C extends SemanticSegmentationModelSources>(
    namedSources: C,
    onDownloadProgress: (progress: number) => void = () => {}
  ): Promise<SemanticSegmentationModule<ModelNameOf<C>>> {
    const { modelName, modelSource } = namedSources;
    // Read `labelMap` from the precisely-typed entry so the instance keeps the
    // model-specific label type; the loader only needs the widened config view.
    const { labelMap } = ModelConfigs[modelName];
    const nativeModule = await SemanticSegmentationModule.loadNativeModule(
      modelSource,
      ModelConfigs[modelName] as SemanticSegmentationConfig<LabelEnum>,
      onDownloadProgress
    );
    return new SemanticSegmentationModule<ModelNameOf<C>>(
      labelMap as ResolveLabels<ModelNameOf<C>>,
      nativeModule
    );
  }

  /**
   * Creates a segmentation instance with a user-provided model binary and label map.
   * Use this when working with a custom-exported segmentation model that is not one of the built-in models.
   *
   * ## Required model contract
   *
   * The `.pte` model binary must expose a single `forward` method with the following interface:
   *
   * **Input:** one `float32` tensor of shape `[1, 3, H, W]` — a single RGB image, values in
   * `[0, 1]` after optional per-channel normalization `(pixel − mean) / std`.
   * H and W are read from the model's declared input shape at load time.
   *
   * **Output:** one `float32` tensor of shape `[1, C, H_out, W_out]` (NCHW) containing raw
   * logits — one channel per class, in the same order as the entries in your `labelMap`.
   * For binary segmentation a single-channel output is also supported: channel 0 is treated
   * as the foreground probability and a synthetic background channel is added automatically.
   *
   * Preprocessing (resize → normalize) and postprocessing (softmax, argmax, resize back to
   * original dimensions) are handled by the native runtime.
   * @param modelSource - A fetchable resource pointing to the model binary.
   * @param config - A {@link SemanticSegmentationConfig} object with the label map and optional preprocessing parameters.
   * @param onDownloadProgress - Optional callback to monitor download progress, receiving a value between 0 and 1.
   * @returns A Promise resolving to a `SemanticSegmentationModule` instance typed to the provided label map.
   * @example
   * ```ts
   * const MyLabels = { BACKGROUND: 0, FOREGROUND: 1 } as const;
   * const segmentation = await SemanticSegmentationModule.fromCustomModel(
   *   'https://example.com/custom_model.pte',
   *   { labelMap: MyLabels },
   * );
   * ```
   */
  static async fromCustomModel<L extends LabelEnum>(
    modelSource: ResourceSource,
    config: SemanticSegmentationConfig<L>,
    onDownloadProgress: (progress: number) => void = () => {}
  ): Promise<SemanticSegmentationModule<L>> {
    const nativeModule = await SemanticSegmentationModule.loadNativeModule(
      modelSource,
      config,
      onDownloadProgress
    );
    return new SemanticSegmentationModule<L>(
      config.labelMap as ResolveLabels<L>,
      nativeModule
    );
  }

  /**
   * Executes the model's forward pass to perform semantic segmentation on the provided image.
   *
   * Supports two input types:
   * 1. **String path/URI**: File path, URL, or Base64-encoded string
   * 2. **PixelData**: Raw pixel data from image libraries (e.g., NitroImage)
   *
   * **Note**: For VisionCamera frame processing, use `runOnFrame` instead.
   * @param input - Image source (string or PixelData object)
   * @param classesOfInterest - An optional list of label keys indicating which per-class probability masks to include in the output. `ARGMAX` is always returned regardless.
   * @param resizeToInput - Whether to resize the output masks to the original input image dimensions. If `false`, returns the raw model output dimensions. Defaults to `true`.
   * @returns A Promise resolving to an object with an `'ARGMAX'` key mapped to an `Int32Array` of per-pixel class indices, and each requested class label mapped to a `Float32Array` of per-pixel probabilities.
   * @throws {RnExecutorchError} If the model is not loaded.
   */
  override async forward<K extends keyof ResolveLabels<T>>(
    input: string | PixelData,
    classesOfInterest: K[] = [],
    resizeToInput: boolean = true
  ): Promise<Record<'ARGMAX', Int32Array> & Record<K, Float32Array>> {
    // The base implementation is given the label keys as plain strings.
    const classesOfInterestNames = classesOfInterest.map(String);
    return super.forward(input, classesOfInterestNames, resizeToInput);
  }
}