/*
 * retinanetjs
 * Version: (not recorded)
 * Wrapper for models built using keras-retinanet.
 * 52 lines (51 loc) • 2.14 kB
 * TypeScript
 */
import * as tf from '@tensorflow/tfjs';
import { AnchorParameters } from './anchors';
/**
 * A single detection, as returned in the array produced by `RetinaNet.detect()`.
 *
 * The bounding box is described by two corners, (x1, y1) and (x2, y2).
 * Coordinates are provided as percentages of the image size: x values are
 * relative to the image width and y values to the image height.
 * NOTE(review): whether these are fractions (0–1) or 0–100 is not visible
 * here — confirm against the implementation.
 */
export interface DetectedObject {
    /** Class label of this detection. */
    label: string;
    /** Prediction score of this detection. */
    score: number;

    /** Horizontal coordinate of the first corner. */
    x1: number;
    /** Vertical coordinate of the first corner. */
    y1: number;

    /** Horizontal coordinate of the second corner. */
    x2: number;
    /** Vertical coordinate of the second corner. */
    y2: number;
}
/**
 * Represents a RetinaNet model. Rather than constructing it directly,
 * create it with `load()`.
 */
export declare class RetinaNet {
    /** The underlying TensorFlow.js layers model. */
    readonly model: tf.LayersModel;
    /** Class labels the model can detect. */
    protected readonly classes: string[];
    /** Image preprocessing scheme (`tf` or `caffe`, per the `load()` documentation). */
    protected readonly preprocessingMode: string;
    /**
     * Anchor parameters used to decode the model's box outputs.
     * NOTE(review): the constructor argument is optional while this field is
     * not, so the implementation presumably substitutes defaults — confirm.
     */
    protected readonly anchorParams: AnchorParameters;
    /** NOTE(review): presumably the model's expected input height — confirm. */
    protected readonly height: number;
    /** NOTE(review): presumably the model's expected input width — confirm. */
    protected readonly width: number;
    /**
     * @param model The loaded TensorFlow.js layers model
     * @param classes The list of detected classes
     * @param preprocessingMode One of `tf` or `caffe`
     * @param anchorParams Optional anchor parameters for the model
     */
    constructor(model: tf.LayersModel, classes: string[], preprocessingMode: string, anchorParams?: AnchorParameters);
    /**
     * Computes predictions. Class-specific filtering is not currently supported:
     * when non-max suppression is applied, it runs across all boxes, regardless of class.
     *
     * @param img The image on which to run object detection
     * @param threshold The prediction threshold (optional; the default is chosen by the implementation)
     * @param nmsThreshold The non-max suppression IoU threshold (optional; the default is chosen by the implementation)
     * @returns A promise resolving to the detected objects
     */
    detect(img: tf.Tensor3D | ImageData | HTMLImageElement | HTMLCanvasElement | HTMLVideoElement, threshold?: number, nmsThreshold?: number): Promise<DetectedObject[]>;
    /**
     * Remove the model from memory.
     */
    dispose(): void;
    /** Private helper; its signature is not emitted in this declaration. */
    private handleImageTensor;
}
/**
 * Loads a keras-retinanet model converted for TensorFlow.js and wraps it in a
 * `RetinaNet` instance ready for `detect()`.
 *
 * @param modelPath The path to the model or a `tf.io.IOHandler` object
 * @param classes The list of detected classes
 * @param preprocessingMode One of `tf` or `caffe`. Check the `preprocess_images`
 * method of your backbone to see which you should use.
 * @param onProgress A callback to report progress
 * (NOTE(review): the scale of `progress` is not visible here — confirm)
 * @param anchorParams The anchor parameters for your model
 * @returns A promise resolving to the loaded `RetinaNet`
 */
export declare function load(modelPath: string | tf.io.IOHandler, classes: string[], preprocessingMode: string, onProgress?: (progress: number, message: string) => void, anchorParams?: AnchorParameters): Promise<RetinaNet>;