UNPKG

react-native-executorch

Version:

An easy way to run AI models in React Native with ExecuTorch

152 lines (143 loc) 7.57 kB
"use strict";

import { DeeplabLabel, SelfieSegmentationLabel } from '../../types/semanticSegmentation';
import { IMAGENET1K_MEAN, IMAGENET1K_STD } from '../../constants/commonVision';
import { fetchModelPath, VisionLabeledModule } from './VisionLabeledModule';

// Every Pascal VOC-trained variant shares one config: the Deeplab label map
// plus ImageNet-1k channel normalization.
const PascalVocSegmentationConfig = {
  labelMap: DeeplabLabel,
  preprocessorConfig: {
    normMean: IMAGENET1K_MEAN,
    normStd: IMAGENET1K_STD
  }
};

// Built-in model names that all map to the shared Pascal VOC config.
const PASCAL_VOC_MODEL_NAMES = [
  'deeplab-v3-resnet50',
  'deeplab-v3-resnet101',
  'deeplab-v3-mobilenet-v3-large',
  'lraspp-mobilenet-v3-large',
  'fcn-resnet50',
  'fcn-resnet101',
  'deeplab-v3-resnet50-quantized',
  'deeplab-v3-resnet101-quantized',
  'deeplab-v3-mobilenet-v3-large-quantized',
  'lraspp-mobilenet-v3-large-quantized',
  'fcn-resnet50-quantized',
  'fcn-resnet101-quantized'
];

// Model-name → { labelMap, preprocessorConfig } lookup table.
// 'selfie-segmentation' is the only entry with its own label map and
// no preprocessing (preprocessorConfig stays undefined).
const ModelConfigs = {
  ...Object.fromEntries(
    PASCAL_VOC_MODEL_NAMES.map(name => [name, PascalVocSegmentationConfig])
  ),
  'selfie-segmentation': {
    labelMap: SelfieSegmentationLabel,
    preprocessorConfig: undefined
  }
};

// Keeps only the string keys of a label enum. Compiled TS numeric enums carry
// a reverse mapping (numeric keys back to names); those are filtered out.
const enumClassNames = labelMap =>
  Object.keys(labelMap).filter(key => isNaN(Number(key)));

// Shared loader for both factory methods: resolves normalization params
// (empty arrays when no preprocessorConfig is given), downloads the model
// binary, and hands everything to the native runtime.
const loadNative = async (modelSource, preprocessorConfig, labelMap, onDownloadProgress) => {
  const normMean = preprocessorConfig?.normMean ?? [];
  const normStd = preprocessorConfig?.normStd ?? [];
  const modelPath = await fetchModelPath(modelSource, onDownloadProgress);
  return global.loadSemanticSegmentation(modelPath, normMean, normStd, enumClassNames(labelMap));
};

/**
 * Semantic segmentation module with per-model label maps.
 *
 * Construct instances through {@link SemanticSegmentationModule.fromModelName}
 * (built-in models such as `'deeplab-v3-resnet50'`, its `-quantized` sibling,
 * the other Deeplab/LRASPP/FCN variants, or `'selfie-segmentation'`) or
 * through {@link SemanticSegmentationModule.fromCustomModel} with a custom
 * label map.
 * @category Typescript API
 */
export class SemanticSegmentationModule extends VisionLabeledModule {
  constructor(labelMap, nativeModule) {
    super(labelMap, nativeModule);
  }

  /**
   * Creates a segmentation instance for one of the built-in models.
   * @param namedSources - Object with `modelName` (which built-in model to use,
   * looked up in the internal config table) and `modelSource` (where to fetch
   * the model binary from).
   * @param onDownloadProgress - Optional callback receiving download progress
   * in `[0, 1]`.
   * @returns A Promise resolving to a `SemanticSegmentationModule` bound to
   * the chosen model's label map.
   * @example
   * ```ts
   * const segmentation = await SemanticSegmentationModule.fromModelName(DEEPLAB_V3_RESNET50);
   * ```
   */
  static async fromModelName(namedSources, onDownloadProgress = () => {}) {
    const { modelName, modelSource } = namedSources;
    const { labelMap, preprocessorConfig } = ModelConfigs[modelName];
    const nativeModule = await loadNative(modelSource, preprocessorConfig, labelMap, onDownloadProgress);
    return new SemanticSegmentationModule(labelMap, nativeModule);
  }

  /**
   * Creates a segmentation instance from a user-provided model binary and
   * label map — for custom-exported models outside the built-in set.
   *
   * ## Required model contract
   *
   * The `.pte` binary must expose a single `forward` method:
   *
   * **Input:** one `float32` tensor of shape `[1, 3, H, W]` — a single RGB
   * image with values in `[0, 1]` after optional per-channel normalization
   * `(pixel − mean) / std`. H and W are read from the model's declared input
   * shape at load time.
   *
   * **Output:** one `float32` tensor of shape `[1, C, H_out, W_out]` (NCHW)
   * of raw logits — one channel per class, ordered like the entries of
   * `labelMap`. A single-channel output is also supported for binary
   * segmentation: channel 0 is treated as foreground probability and a
   * synthetic background channel is added automatically.
   *
   * Preprocessing (resize → normalize) and postprocessing (softmax, argmax,
   * resize back to original dimensions) happen in the native runtime.
   * @param modelSource - A fetchable resource pointing to the model binary.
   * @param config - Object with the `labelMap` and an optional
   * `preprocessorConfig` (`normMean`/`normStd`).
   * @param onDownloadProgress - Optional callback receiving download progress
   * in `[0, 1]`.
   * @returns A Promise resolving to a `SemanticSegmentationModule` bound to
   * the provided label map.
   * @example
   * ```ts
   * const MyLabels = { BACKGROUND: 0, FOREGROUND: 1 } as const;
   * const segmentation = await SemanticSegmentationModule.fromCustomModel(
   *   'https://example.com/custom_model.pte',
   *   { labelMap: MyLabels },
   * );
   * ```
   */
  static async fromCustomModel(modelSource, config, onDownloadProgress = () => {}) {
    const { labelMap, preprocessorConfig } = config;
    const nativeModule = await loadNative(modelSource, preprocessorConfig, labelMap, onDownloadProgress);
    return new SemanticSegmentationModule(labelMap, nativeModule);
  }

  /**
   * Runs the model's forward pass on the given image.
   *
   * Accepts either a string (file path, URL, or Base64 data) or a PixelData
   * object with raw pixels. For VisionCamera frames use `runOnFrame` instead.
   * @param input - Image source (string or PixelData object).
   * @param classesOfInterest - Optional label keys selecting which per-class
   * probability masks to include; `ARGMAX` is always returned regardless.
   * @param resizeToInput - Whether output masks are resized to the original
   * input dimensions (`true`, default) or left at raw model-output size.
   * @returns A Promise resolving to an object mapping `'ARGMAX'` to an
   * `Int32Array` of per-pixel class indices and each requested label to a
   * `Float32Array` of per-pixel probabilities.
   * @throws {RnExecutorchError} If the model is not loaded.
   */
  async forward(input, classesOfInterest = [], resizeToInput = true) {
    // Label keys are coerced to strings before crossing the native bridge.
    return super.forward(input, classesOfInterest.map(String), resizeToInput);
  }
}
//# sourceMappingURL=SemanticSegmentationModule.js.map