UNPKG

retinanetjs

Version:

A TensorFlow.js wrapper for object-detection models built using keras-retinanet.

52 lines (51 loc) 2.14 kB
import * as tf from '@tensorflow/tfjs';
import { AnchorParameters } from './anchors';

/**
 * Represents a detected object. Coordinates are provided as percentages
 * of the image width and height.
 */
export interface DetectedObject {
  /** Class label of the detection. */
  label: string;
  /** Prediction score of the detection. */
  score: number;
  x1: number;
  x2: number;
  y1: number;
  y2: number;
}

/**
 * Represents a RetinaNet model. Rather than constructing it directly,
 * create instances with `load()`.
 */
export declare class RetinaNet {
  /** The underlying TensorFlow.js Layers model. */
  readonly model: tf.LayersModel;
  protected readonly classes: string[];
  /** Preprocessing mode; see `load()` for the accepted values (`tf` or `caffe`). */
  protected readonly preprocessingMode: string;
  protected readonly anchorParams: AnchorParameters;
  // NOTE(review): presumably the model's expected input height/width — confirm in the implementation.
  protected readonly height: number;
  protected readonly width: number;
  constructor(model: tf.LayersModel, classes: string[], preprocessingMode: string, anchorParams?: AnchorParameters);
  /**
   * Computes predictions. Class-specific filtering is not currently supported:
   * when non-max suppression is applied, it runs across all boxes, regardless of class.
   *
   * @param img The image object on which to run object detection.
   * @param threshold The prediction threshold.
   * @param nmsThreshold The non-max suppression IoU threshold.
   * @returns A promise that resolves to the list of detected objects.
   */
  detect(img: tf.Tensor3D | ImageData | HTMLImageElement | HTMLCanvasElement | HTMLVideoElement, threshold?: number, nmsThreshold?: number): Promise<DetectedObject[]>;
  /**
   * Removes the model from memory.
   */
  dispose(): void;
  private handleImageTensor;
}

/**
 * Loads a model and returns it wrapped in a `RetinaNet` instance.
 * This is the intended way to create a `RetinaNet`.
 *
 * @param modelPath The path to the model or a `tf.io.IOHandler` object.
 * @param classes The list of detected classes.
 * @param preprocessingMode One of `tf` or `caffe`. Check the `preprocess_images`
 * method of your backbone to see which you should use.
 * @param onProgress A callback used to report progress.
 * @param anchorParams The anchor parameters for your model.
 * @returns A promise that resolves to the loaded `RetinaNet`.
 */
export declare function load(modelPath: string | tf.io.IOHandler, classes: string[], preprocessingMode: string, onProgress?: (progress: number, message: string) => void, anchorParams?: AnchorParameters): Promise<RetinaNet>;