UNPKG

react-native-executorch

Version:

An easy way to run AI models in React Native with ExecuTorch

104 lines (97 loc) 4.52 kB
"use strict";

import { Imagenet1kLabel } from '../../constants/classification';
import { fetchModelPath, VisionLabeledModule } from './VisionLabeledModule';

/**
 * Configuration for the built-in classification presets, keyed by model name.
 * Neither preset defines a `preprocessorConfig`, so both are loaded with empty
 * normalization arrays.
 */
const ModelConfigs = {
  'efficientnet-v2-s': {
    labelMap: Imagenet1kLabel
  },
  'efficientnet-v2-s-quantized': {
    labelMap: Imagenet1kLabel
  }
};

/**
 * Resolves the {@link LabelEnum} for a given built-in classification model name.
 * @typeParam M - A built-in model name from {@link ClassificationModelName}.
 * @category Types
 */

/** @internal */

/**
 * Looks up the preset configuration for a built-in model name.
 * Only own properties of `ModelConfigs` are accepted, so inherited keys such as
 * `'toString'` or `'__proto__'` are rejected instead of resolving to prototype members.
 * @param {string} modelName - Name of a built-in model preset.
 * @returns {{ labelMap: object, preprocessorConfig?: object }} The preset's configuration.
 * @throws {TypeError} If `modelName` is not a known preset.
 */
function resolveModelConfig(modelName) {
  if (!Object.prototype.hasOwnProperty.call(ModelConfigs, modelName)) {
    throw new TypeError(`Unknown classification model name: ${String(modelName)}`);
  }
  return ModelConfigs[modelName];
}

/**
 * Converts a label map into a dense array of label names indexed by class id.
 *
 * Only entries whose value is a number are used. This skips string-valued
 * entries, such as the reverse (number -> name) entries that compiled TypeScript
 * numeric enums contain. Indices not claimed by any label become `''` so the
 * array has no holes. If two names share an index, the one enumerated last wins.
 * @param {object} labelMap - Mapping of label name to class index.
 * @returns {string[]} Label names ordered by class index.
 */
function buildLabelNames(labelMap) {
  const sparseNames = [];
  for (const [name, value] of Object.entries(labelMap)) {
    if (typeof value === 'number') sparseNames[value] = name;
  }
  // Array.from visits holes as `undefined`, which lets us densify in one pass.
  return Array.from(sparseNames, (name) => name ?? '');
}

/**
 * Shared load pipeline for both factory methods. It builds the label names,
 * downloads or resolves the model binary, and hands everything to the native
 * classification loader.
 * @param {*} modelSource - Fetchable resource pointing to the model binary.
 * @param {object} labelMap - Mapping of label name to class index.
 * @param {{ normMean?: number[], normStd?: number[] } | undefined} preprocessorConfig -
 *   Optional per-channel normalization; missing values default to `[]`.
 * @param {(progress: number) => void} onDownloadProgress - Download progress callback.
 * @returns {Promise<*>} The native module returned by `global.loadClassification`.
 */
async function loadNativeClassifier(modelSource, labelMap, preprocessorConfig, onDownloadProgress) {
  const normMean = preprocessorConfig?.normMean ?? [];
  const normStd = preprocessorConfig?.normStd ?? [];
  const labelNames = buildLabelNames(labelMap);
  const modelPath = await fetchModelPath(modelSource, onDownloadProgress);
  return global.loadClassification(modelPath, normMean, normStd, labelNames);
}

/**
 * Generic classification module with type-safe label maps.
 * @typeParam T - Either a built-in model name (e.g. `'efficientnet-v2-s'`)
 *   or a custom {@link LabelEnum} label map.
 * @category Typescript API
 */
export class ClassificationModule extends VisionLabeledModule {
  constructor(labelMap, nativeModule) {
    super(labelMap, nativeModule);
  }

  /**
   * Creates a classification instance for a built-in model.
   * @param namedSources - A {@link ClassificationModelSources} object specifying which model to load and where to fetch it from.
   * @param onDownloadProgress - Optional callback to monitor download progress, receiving a value between 0 and 1.
   * @returns A Promise resolving to a `ClassificationModule` instance typed to the chosen model's label map.
   * @throws {TypeError} If `namedSources.modelName` is not a built-in preset (thrown before any download starts).
   */
  static async fromModelName(namedSources, onDownloadProgress = () => {}) {
    const { modelSource, modelName } = namedSources;
    const { labelMap, preprocessorConfig } = resolveModelConfig(modelName);
    const nativeModule = await loadNativeClassifier(modelSource, labelMap, preprocessorConfig, onDownloadProgress);
    return new ClassificationModule(labelMap, nativeModule);
  }

  /**
   * Creates a classification instance with a user-provided model binary and label map.
   * Use this when working with a custom-exported model that is not one of the built-in presets.
   *
   * ## Required model contract
   *
   * The `.pte` model binary must expose a single `forward` method with the following interface:
   *
   * **Input:** one `float32` tensor of shape `[1, 3, H, W]` — a single RGB image, values in
   * `[0, 1]` after optional per-channel normalization `(pixel − mean) / std`.
   * H and W are read from the model's declared input shape at load time.
   *
   * **Output:** one `float32` tensor of shape `[1, C]` containing raw logits — one value per class,
   * in the same order as the entries in your `labelMap`. Softmax is applied by the native runtime.
   * @param modelSource - A fetchable resource pointing to the model binary.
   * @param config - A {@link ClassificationConfig} object with the label map and optional preprocessing parameters.
   * @param onDownloadProgress - Optional callback to monitor download progress, receiving a value between 0 and 1.
   * @returns A Promise resolving to a `ClassificationModule` instance typed to the provided label map.
   */
  static async fromCustomModel(modelSource, config, onDownloadProgress = () => {}) {
    const { labelMap, preprocessorConfig } = config;
    const nativeModule = await loadNativeClassifier(modelSource, labelMap, preprocessorConfig, onDownloadProgress);
    return new ClassificationModule(labelMap, nativeModule);
  }

  /**
   * Executes the model's forward pass to classify the provided image.
   * @param input - A string image source (file path, URI, or Base64) or a {@link PixelData} object.
   * @returns A Promise resolving to an object mapping label keys to confidence scores.
   */
  async forward(input) {
    return super.forward(input);
  }
}
//# sourceMappingURL=ClassificationModule.js.map