@vladmandic/face-api
Version:
FaceAPI: AI-powered Face Detection & Rotation Tracking, Face Description & Recognition, Age & Gender & Emotion Prediction for Browser and NodeJS using TensorFlow/JS
152 lines (116 loc) • 5 kB
text/typescript
import * as tf from '../../dist/tfjs.esm';
import { Dimensions } from '../classes/Dimensions';
import { env } from '../env/index';
import { padToSquare } from '../ops/padToSquare';
import { computeReshapedDimensions, isTensor3D, isTensor4D, range } from '../utils/index';
import { createCanvasFromMedia } from './createCanvas';
import { imageToSquare } from './imageToSquare';
import { TResolvedNetInput } from './types';
/**
 * Normalized wrapper around the inputs of a network forward pass.
 *
 * Each input is stored in one of two forms:
 * - kept as a tensor (a tf.Tensor3D, or a tf.Tensor4D whose batch dimension is 1), or
 * - stored as a canvas (existing canvases as-is, other media via createCanvasFromMedia).
 *
 * The per-input dimensions are recorded at construction time. They are used later to
 * compute reshaped dimensions for a given network input size.
 */
export class NetInput {
  // Tensor inputs, indexed by their position in the constructor's input array.
  private readonly _imageTensors: Array<tf.Tensor3D | tf.Tensor4D> = [];

  // Canvas inputs, indexed by their position in the constructor's input array.
  private readonly _canvases: HTMLCanvasElement[] = [];

  private readonly _batchSize: number;

  private readonly _treatAsBatchInput: boolean;

  // One dimension array per input. For canvases this is [height, width, 3]. For tensors
  // it is the tensor shape (minus the batch axis for 4D tensors); presumably [h, w, c].
  private readonly _inputDimensions: number[][] = [];

  // Set by toBatchTensor(). It stays undefined until that method has been called.
  private _inputSize: number | undefined;

  /**
   * @param inputs Resolved inputs. Tensor inputs must be 3D, or 4D with a batch size of 1.
   *   Every other input is used as a canvas, or converted to one.
   * @param treatAsBatchInput Report isBatchInput as true even for a single input.
   * @throws Error if `inputs` is not an array, or a tf.Tensor4D with batch size != 1 is passed.
   */
  constructor(inputs: Array<TResolvedNetInput>, treatAsBatchInput = false) {
    if (!Array.isArray(inputs)) {
      throw new Error(`NetInput.constructor - expected inputs to be an Array of TResolvedNetInput or to be instanceof tf.Tensor4D, instead have ${inputs}`);
    }
    this._treatAsBatchInput = treatAsBatchInput;
    this._batchSize = inputs.length;
    inputs.forEach((input, idx) => {
      if (isTensor3D(input)) {
        this._imageTensors[idx] = input;
        this._inputDimensions[idx] = input.shape;
        return;
      }
      if (isTensor4D(input)) {
        const batchSize = (input as any).shape[0];
        // A single NetInput entry represents exactly one image.
        if (batchSize !== 1) {
          throw new Error(`NetInput - tf.Tensor4D with batchSize ${batchSize} passed, but not supported in input array`);
        }
        this._imageTensors[idx] = input;
        // Drop the leading batch axis so dimensions match the 3D case.
        this._inputDimensions[idx] = (input as any).shape.slice(1);
        return;
      }
      const canvas = (input as any) instanceof env.getEnv().Canvas ? input : createCanvasFromMedia(input);
      this._canvases[idx] = canvas;
      this._inputDimensions[idx] = [canvas.height, canvas.width, 3];
    });
  }

  /** Tensor inputs, sparse-indexed by input position. */
  public get imageTensors(): Array<tf.Tensor3D | tf.Tensor4D> {
    return this._imageTensors;
  }

  /** Canvas inputs, sparse-indexed by input position. */
  public get canvases(): HTMLCanvasElement[] {
    return this._canvases;
  }

  /** True when there is more than one input, or when treatAsBatchInput was requested. */
  public get isBatchInput(): boolean {
    return this.batchSize > 1 || this._treatAsBatchInput;
  }

  /** Number of inputs passed to the constructor. */
  public get batchSize(): number {
    return this._batchSize;
  }

  /** Recorded dimensions of every input (see getInputDimensions). */
  public get inputDimensions(): number[][] {
    return this._inputDimensions;
  }

  /** Size passed to the last toBatchTensor() call, or undefined if it was never called. */
  public get inputSize(): number | undefined {
    return this._inputSize;
  }

  /**
   * Reshaped dimensions of every input.
   *
   * @throws Error if toBatchTensor() has not been called yet.
   */
  public get reshapedInputDimensions(): Dimensions[] {
    return range(this.batchSize, 0, 1).map(
      (_, batchIdx) => this.getReshapedInputDimensions(batchIdx),
    );
  }

  /** Returns the canvas stored at `batchIdx` if any, otherwise the tensor stored there. */
  public getInput(batchIdx: number): tf.Tensor3D | tf.Tensor4D | HTMLCanvasElement {
    return this.canvases[batchIdx] || this.imageTensors[batchIdx];
  }

  /** Returns the recorded dimensions of the input at `batchIdx`. */
  public getInputDimensions(batchIdx: number): number[] {
    return this._inputDimensions[batchIdx];
  }

  /** Returns the first recorded dimension (height) of the input at `batchIdx`. */
  public getInputHeight(batchIdx: number): number {
    return this._inputDimensions[batchIdx][0];
  }

  /** Returns the second recorded dimension (width) of the input at `batchIdx`. */
  public getInputWidth(batchIdx: number): number {
    return this._inputDimensions[batchIdx][1];
  }

  /**
   * Computes the dimensions of the input at `batchIdx` for the current inputSize,
   * as produced by computeReshapedDimensions.
   *
   * @throws Error if toBatchTensor() has not been called yet (inputSize is unset).
   */
  public getReshapedInputDimensions(batchIdx: number): Dimensions {
    if (typeof this.inputSize !== 'number') {
      throw new Error('getReshapedInputDimensions - inputSize not set, toBatchTensor has not been called yet');
    }
    const width = this.getInputWidth(batchIdx);
    const height = this.getInputHeight(batchIdx);
    return computeReshapedDimensions({ width, height }, this.inputSize);
  }

  /**
   * Creates a float32 batch tensor of shape [batchSize, inputSize, inputSize, 3] from all
   * input canvases and tensors.
   *
   * Tensor inputs are handled as follows:
   * - expanded to 4D if needed;
   * - padded to a square with padToSquare;
   * - resized bilinearly to inputSize x inputSize, unless already that size.
   *
   * Canvas inputs are drawn into a square with imageToSquare and read with
   * tf.browser.fromPixels. Intermediate tensors are released by tf.tidy.
   *
   * Side effect: records `inputSize`, which enables getReshapedInputDimensions().
   *
   * @param inputSize Height and width of each image in the resulting tensor.
   * @param isCenterInputs (optional, default: true) If true, add an equal amount of padding
   *   on both sides of the minor dimension of the image. Forwarded to padToSquare and
   *   imageToSquare.
   * @returns The batch tensor.
   * @throws Error if an input is neither a tf.Tensor nor a canvas.
   */
  public toBatchTensor(inputSize: number, isCenterInputs = true): tf.Tensor4D {
    this._inputSize = inputSize;
    return tf.tidy(() => {
      const inputTensors = range(this.batchSize, 0, 1).map((batchIdx) => {
        const input = this.getInput(batchIdx);
        if (input instanceof tf.Tensor) {
          let imgTensor = isTensor4D(input) ? input : tf.expandDims(input);
          imgTensor = padToSquare(imgTensor, isCenterInputs);
          if (imgTensor.shape[1] !== inputSize || imgTensor.shape[2] !== inputSize) {
            imgTensor = tf.image.resizeBilinear(imgTensor, [inputSize, inputSize], false, false);
          }
          return imgTensor.as3D(inputSize, inputSize, 3);
        }
        if (input instanceof env.getEnv().Canvas) {
          return tf.browser.fromPixels(imageToSquare(input, inputSize, isCenterInputs));
        }
        throw new Error(`toBatchTensor - at batchIdx ${batchIdx}, expected input to be instanceof tf.Tensor or instanceof HTMLCanvasElement, instead have ${input}`);
      });
      // Cast every image to float32 so canvas (int32) and tensor inputs can be stacked together.
      const batchTensor = tf.stack(inputTensors.map((t) => tf.cast(t, 'float32'))).as4D(this.batchSize, inputSize, inputSize, 3);
      return batchTensor;
    });
  }
}