react-native-executorch
Version:
An easy way to run AI models in React Native with ExecuTorch
104 lines (97 loc) • 4.52 kB
JavaScript
;
import { Imagenet1kLabel } from '../../constants/classification';
import { fetchModelPath, VisionLabeledModule } from './VisionLabeledModule';
// Built-in model presets. Both bundled EfficientNetV2-S variants decode
// their class indices with the ImageNet-1k label set; each preset gets its
// own config object so entries stay independent.
const ModelConfigs = Object.fromEntries(
  ['efficientnet-v2-s', 'efficientnet-v2-s-quantized'].map((modelName) => [
    modelName,
    { labelMap: Imagenet1kLabel },
  ])
);
/**
* Resolves the {@link LabelEnum} for a given built-in classification model name.
* @typeParam M - A built-in model name from {@link ClassificationModelName}.
* @category Types
*/
/** @internal */
/**
* Generic classification module with type-safe label maps.
* @typeParam T - Either a built-in model name (e.g. `'efficientnet-v2-s'`)
* or a custom {@link LabelEnum} label map.
* @category Typescript API
*/
export class ClassificationModule extends VisionLabeledModule {
  /**
   * @param labelMap - Label map used to type the classification results.
   * @param nativeModule - Handle to an already-loaded native classification model.
   */
  constructor(labelMap, nativeModule) {
    super(labelMap, nativeModule);
  }

  /**
   * Converts a label map into the dense, index-ordered array of label names
   * expected by the native loader. Entries whose value is a number are
   * treated as `name -> classIndex` pairs (non-numeric values — e.g. the
   * reverse-lookup entries of a TypeScript numeric enum — are skipped), and
   * any gaps in the index range are filled with empty strings.
   */
  static #denseLabelNames(labelMap) {
    const names = [];
    for (const [name, index] of Object.entries(labelMap)) {
      if (typeof index === 'number') names[index] = name;
    }
    // The native side expects a dense string array; plug the sparse holes.
    for (let i = 0; i < names.length; i++) {
      if (names[i] == null) names[i] = '';
    }
    return names;
  }

  /**
   * Shared load path for both factories: resolves the optional preprocessing
   * parameters, downloads the model binary, and instantiates the native module.
   */
  static async #loadNativeModule(modelSource, labelMap, preprocessorConfig, onDownloadProgress) {
    const normMean = preprocessorConfig?.normMean ?? [];
    const normStd = preprocessorConfig?.normStd ?? [];
    const allLabelNames = ClassificationModule.#denseLabelNames(labelMap);
    const modelPath = await fetchModelPath(modelSource, onDownloadProgress);
    return global.loadClassification(modelPath, normMean, normStd, allLabelNames);
  }

  /**
   * Creates a classification instance for a built-in model.
   * @param namedSources - A {@link ClassificationModelSources} object specifying which model to load and where to fetch it from.
   * @param onDownloadProgress - Optional callback to monitor download progress, receiving a value between 0 and 1.
   * @returns A Promise resolving to a `ClassificationModule` instance typed to the chosen model's label map.
   * @throws {Error} If `namedSources.modelName` is not one of the built-in presets.
   */
  static async fromModelName(namedSources, onDownloadProgress = () => {}) {
    const { modelSource, modelName } = namedSources;
    const config = ModelConfigs[modelName];
    if (config == null) {
      // Fail with a descriptive message instead of a cryptic destructuring TypeError.
      throw new Error(`Unknown built-in classification model: ${String(modelName)}`);
    }
    const { labelMap, preprocessorConfig } = config;
    const nativeModule = await ClassificationModule.#loadNativeModule(modelSource, labelMap, preprocessorConfig, onDownloadProgress);
    return new ClassificationModule(labelMap, nativeModule);
  }

  /**
   * Creates a classification instance with a user-provided model binary and label map.
   * Use this when working with a custom-exported model that is not one of the built-in presets.
   *
   * ## Required model contract
   *
   * The `.pte` model binary must expose a single `forward` method with the following interface:
   *
   * **Input:** one `float32` tensor of shape `[1, 3, H, W]` — a single RGB image, values in
   * `[0, 1]` after optional per-channel normalization `(pixel − mean) / std`.
   * H and W are read from the model's declared input shape at load time.
   *
   * **Output:** one `float32` tensor of shape `[1, C]` containing raw logits — one value per class,
   * in the same order as the entries in your `labelMap`. Softmax is applied by the native runtime.
   * @param modelSource - A fetchable resource pointing to the model binary.
   * @param config - A {@link ClassificationConfig} object with the label map and optional preprocessing parameters.
   * @param onDownloadProgress - Optional callback to monitor download progress, receiving a value between 0 and 1.
   * @returns A Promise resolving to a `ClassificationModule` instance typed to the provided label map.
   */
  static async fromCustomModel(modelSource, config, onDownloadProgress = () => {}) {
    const { labelMap, preprocessorConfig } = config;
    const nativeModule = await ClassificationModule.#loadNativeModule(modelSource, labelMap, preprocessorConfig, onDownloadProgress);
    return new ClassificationModule(labelMap, nativeModule);
  }

  /**
   * Executes the model's forward pass to classify the provided image.
   * @param input - A string image source (file path, URI, or Base64) or a {@link PixelData} object.
   * @returns A Promise resolving to an object mapping label keys to confidence scores.
   */
  async forward(input) {
    return super.forward(input);
  }
}
//# sourceMappingURL=ClassificationModule.js.map