// react-native-executorch — An easy way to run AI models in React Native with ExecuTorch
// (Compiled JavaScript output; original source is TypeScript.)
;
import { DeeplabLabel, SelfieSegmentationLabel } from '../../types/semanticSegmentation';
import { IMAGENET1K_MEAN, IMAGENET1K_STD } from '../../constants/commonVision';
import { fetchModelPath, VisionLabeledModule } from './VisionLabeledModule';
/**
 * Shared configuration for every Pascal VOC–trained segmentation model:
 * the DeepLab label map together with ImageNet-1k per-channel normalization.
 * Reused by all DeepLab/LRASPP/FCN entries in {@link ModelConfigs}.
 */
const PascalVocSegmentationConfig = {
  labelMap: DeeplabLabel,
  preprocessorConfig: { normMean: IMAGENET1K_MEAN, normStd: IMAGENET1K_STD },
};
/**
 * Base names of the built-in Pascal VOC segmentation models; each also has a
 * `-quantized` variant sharing the same configuration.
 */
const PASCAL_VOC_MODEL_NAMES = [
  'deeplab-v3-resnet50',
  'deeplab-v3-resnet101',
  'deeplab-v3-mobilenet-v3-large',
  'lraspp-mobilenet-v3-large',
  'fcn-resnet50',
  'fcn-resnet101',
];

/**
 * Maps every built-in model name to its label map and optional preprocessing
 * config. Entry order: float models, then quantized variants, then selfie.
 */
const ModelConfigs = {
  ...Object.fromEntries(
    PASCAL_VOC_MODEL_NAMES.map(name => [name, PascalVocSegmentationConfig])
  ),
  ...Object.fromEntries(
    PASCAL_VOC_MODEL_NAMES.map(name => [
      `${name}-quantized`,
      PascalVocSegmentationConfig,
    ])
  ),
  'selfie-segmentation': {
    labelMap: SelfieSegmentationLabel,
    preprocessorConfig: undefined,
  },
};
/** @internal */
/**
* Resolves the {@link LabelEnum} for a given built-in model name.
* @typeParam M - A built-in model name from {@link SemanticSegmentationModelName}.
* @category Types
*/
/** @internal */
/**
* Generic semantic segmentation module with type-safe label maps.
* Use a model name (e.g. `'deeplab-v3-resnet50'`) as the generic parameter for built-in models,
* or a custom label enum for custom configs.
* @typeParam T - Either a built-in model name (`'deeplab-v3-resnet50'`,
* `'deeplab-v3-resnet50-quantized'`, `'deeplab-v3-resnet101'`,
* `'deeplab-v3-resnet101-quantized'`, `'deeplab-v3-mobilenet-v3-large'`,
* `'deeplab-v3-mobilenet-v3-large-quantized'`, `'lraspp-mobilenet-v3-large'`,
* `'lraspp-mobilenet-v3-large-quantized'`, `'fcn-resnet50'`,
* `'fcn-resnet50-quantized'`, `'fcn-resnet101'`, `'fcn-resnet101-quantized'`,
* `'selfie-segmentation'`) or a custom {@link LabelEnum} label map.
* @category Typescript API
*/
export class SemanticSegmentationModule extends VisionLabeledModule {
  /**
   * @param labelMap - Enum-style object mapping class names to channel indices.
   * @param nativeModule - Handle returned by `global.loadSemanticSegmentation`.
   */
  constructor(labelMap, nativeModule) {
    super(labelMap, nativeModule);
  }
  /**
   * Creates a segmentation instance for a built-in model.
   * The config object is discriminated by `modelName` — each model can require different fields.
   * @param namedSources - A {@link SemanticSegmentationModelSources} object specifying which model to load and where to fetch it from.
   * @param onDownloadProgress - Optional callback to monitor download progress, receiving a value between 0 and 1.
   * @returns A Promise resolving to a `SemanticSegmentationModule` instance typed to the chosen model's label map.
   * @throws {Error} If `modelName` is not one of the built-in models in `ModelConfigs`.
   * @example
   * ```ts
   * const segmentation = await SemanticSegmentationModule.fromModelName(DEEPLAB_V3_RESNET50);
   * ```
   */
  static async fromModelName(namedSources, onDownloadProgress = () => {}) {
    const {
      modelName,
      modelSource
    } = namedSources;
    // Single lookup (the original destructured ModelConfigs[modelName] twice),
    // with an explicit guard so an unknown name fails with an actionable
    // message instead of an opaque destructuring TypeError.
    const config = ModelConfigs[modelName];
    if (config === undefined) {
      throw new Error(`Unknown semantic segmentation model name: '${modelName}'. ` + `Expected one of: ${Object.keys(ModelConfigs).join(', ')}.`);
    }
    const {
      labelMap,
      preprocessorConfig
    } = config;
    // Models without a preprocessor config (e.g. selfie-segmentation) pass
    // empty mean/std arrays, which the native side treats as "no normalization".
    const normMean = preprocessorConfig?.normMean ?? [];
    const normStd = preprocessorConfig?.normStd ?? [];
    // Enum objects from compiled TS map both name -> index and index -> name;
    // keep only the non-numeric keys (the class names).
    const allClassNames = Object.keys(labelMap).filter(k => isNaN(Number(k)));
    const modelPath = await fetchModelPath(modelSource, onDownloadProgress);
    const nativeModule = await global.loadSemanticSegmentation(modelPath, normMean, normStd, allClassNames);
    return new SemanticSegmentationModule(labelMap, nativeModule);
  }
  /**
   * Creates a segmentation instance with a user-provided model binary and label map.
   * Use this when working with a custom-exported segmentation model that is not one of the built-in models.
   * Internally uses `'custom'` as the model name for telemetry unless overridden.
   *
   * ## Required model contract
   *
   * The `.pte` model binary must expose a single `forward` method with the following interface:
   *
   * **Input:** one `float32` tensor of shape `[1, 3, H, W]` — a single RGB image, values in
   * `[0, 1]` after optional per-channel normalization `(pixel − mean) / std`.
   * H and W are read from the model's declared input shape at load time.
   *
   * **Output:** one `float32` tensor of shape `[1, C, H_out, W_out]` (NCHW) containing raw
   * logits — one channel per class, in the same order as the entries in your `labelMap`.
   * For binary segmentation a single-channel output is also supported: channel 0 is treated
   * as the foreground probability and a synthetic background channel is added automatically.
   *
   * Preprocessing (resize → normalize) and postprocessing (softmax, argmax, resize back to
   * original dimensions) are handled by the native runtime.
   * @param modelSource - A fetchable resource pointing to the model binary.
   * @param config - A {@link SemanticSegmentationConfig} object with the label map and optional preprocessing parameters.
   * @param onDownloadProgress - Optional callback to monitor download progress, receiving a value between 0 and 1.
   * @returns A Promise resolving to a `SemanticSegmentationModule` instance typed to the provided label map.
   * @example
   * ```ts
   * const MyLabels = { BACKGROUND: 0, FOREGROUND: 1 } as const;
   * const segmentation = await SemanticSegmentationModule.fromCustomModel(
   *   'https://example.com/custom_model.pte',
   *   { labelMap: MyLabels },
   * );
   * ```
   */
  static async fromCustomModel(modelSource, config, onDownloadProgress = () => {}) {
    // Same normalization defaults as fromModelName: empty arrays mean
    // "no normalization" on the native side.
    const normMean = config.preprocessorConfig?.normMean ?? [];
    const normStd = config.preprocessorConfig?.normStd ?? [];
    // Keep only non-numeric keys to support both plain objects and
    // compiled-TS enums (which carry reverse numeric mappings).
    const allClassNames = Object.keys(config.labelMap).filter(k => isNaN(Number(k)));
    const modelPath = await fetchModelPath(modelSource, onDownloadProgress);
    const nativeModule = await global.loadSemanticSegmentation(modelPath, normMean, normStd, allClassNames);
    return new SemanticSegmentationModule(config.labelMap, nativeModule);
  }
  /**
   * Executes the model's forward pass to perform semantic segmentation on the provided image.
   *
   * Supports two input types:
   * 1. **String path/URI**: File path, URL, or Base64-encoded string
   * 2. **PixelData**: Raw pixel data from image libraries (e.g., NitroImage)
   *
   * **Note**: For VisionCamera frame processing, use `runOnFrame` instead.
   * @param input - Image source (string or PixelData object)
   * @param classesOfInterest - An optional list of label keys indicating which per-class probability masks to include in the output. `ARGMAX` is always returned regardless.
   * @param resizeToInput - Whether to resize the output masks to the original input image dimensions. If `false`, returns the raw model output dimensions. Defaults to `true`.
   * @returns A Promise resolving to an object with an `'ARGMAX'` key mapped to an `Int32Array` of per-pixel class indices, and each requested class label mapped to a `Float32Array` of per-pixel probabilities.
   * @throws {RnExecutorchError} If the model is not loaded.
   */
  async forward(input, classesOfInterest = [], resizeToInput = true) {
    // The native layer expects class names as strings; coerce enum members.
    const classesOfInterestNames = classesOfInterest.map(String);
    return super.forward(input, classesOfInterestNames, resizeToInput);
  }
}
//# sourceMappingURL=SemanticSegmentationModule.js.map