UNPKG

@vladmandic/face-api

Version:

FaceAPI: AI-powered Face Detection & Rotation Tracking, Face Description & Recognition, Age & Gender & Emotion Prediction for Browser and NodeJS using TensorFlow/JS

149 lines (122 loc) 4.89 kB
import * as tf from '../../dist/tfjs.esm'; import { ConvParams, extractWeightsFactory, ExtractWeightsFunction, ParamMapping } from '../common/index'; import { isFloat } from '../utils/index'; import { ConvLayerParams, NetParams, ResidualLayerParams, ScaleLayerParams } from './types'; function extractorsFactory(extractWeights: ExtractWeightsFunction, paramMappings: ParamMapping[]) { function extractFilterValues(numFilterValues: number, numFilters: number, filterSize: number): tf.Tensor4D { const weights = extractWeights(numFilterValues); const depth = weights.length / (numFilters * filterSize * filterSize); if (isFloat(depth)) { throw new Error(`depth has to be an integer: ${depth}, weights.length: ${weights.length}, numFilters: ${numFilters}, filterSize: ${filterSize}`); } return tf.tidy( () => tf.transpose( tf.tensor4d(weights, [numFilters, depth, filterSize, filterSize]), [2, 3, 1, 0], ), ); } function extractConvParams( numFilterValues: number, numFilters: number, filterSize: number, mappedPrefix: string, ): ConvParams { const filters = extractFilterValues(numFilterValues, numFilters, filterSize); const bias = tf.tensor1d(extractWeights(numFilters)); paramMappings.push( { paramPath: `${mappedPrefix}/filters` }, { paramPath: `${mappedPrefix}/bias` }, ); return { filters, bias }; } function extractScaleLayerParams(numWeights: number, mappedPrefix: string): ScaleLayerParams { const weights = tf.tensor1d(extractWeights(numWeights)); const biases = tf.tensor1d(extractWeights(numWeights)); paramMappings.push( { paramPath: `${mappedPrefix}/weights` }, { paramPath: `${mappedPrefix}/biases` }, ); return { weights, biases, }; } function extractConvLayerParams( numFilterValues: number, numFilters: number, filterSize: number, mappedPrefix: string, ): ConvLayerParams { const conv = extractConvParams(numFilterValues, numFilters, filterSize, `${mappedPrefix}/conv`); const scale = extractScaleLayerParams(numFilters, `${mappedPrefix}/scale`); return { conv, scale }; } function extractResidualLayerParams( numFilterValues: number, numFilters: number, filterSize: number, mappedPrefix: string, isDown = false, ): ResidualLayerParams { const conv1 = extractConvLayerParams((isDown ? 0.5 : 1) * numFilterValues, numFilters, filterSize, `${mappedPrefix}/conv1`); const conv2 = extractConvLayerParams(numFilterValues, numFilters, filterSize, `${mappedPrefix}/conv2`); return { conv1, conv2 }; } return { extractConvLayerParams, extractResidualLayerParams, }; } export function extractParams(weights: Float32Array): { params: NetParams, paramMappings: ParamMapping[] } { const { extractWeights, getRemainingWeights, } = extractWeightsFactory(weights); const paramMappings: ParamMapping[] = []; const { extractConvLayerParams, extractResidualLayerParams, } = extractorsFactory(extractWeights, paramMappings); const conv32_down = extractConvLayerParams(4704, 32, 7, 'conv32_down'); const conv32_1 = extractResidualLayerParams(9216, 32, 3, 'conv32_1'); const conv32_2 = extractResidualLayerParams(9216, 32, 3, 'conv32_2'); const conv32_3 = extractResidualLayerParams(9216, 32, 3, 'conv32_3'); const conv64_down = extractResidualLayerParams(36864, 64, 3, 'conv64_down', true); const conv64_1 = extractResidualLayerParams(36864, 64, 3, 'conv64_1'); const conv64_2 = extractResidualLayerParams(36864, 64, 3, 'conv64_2'); const conv64_3 = extractResidualLayerParams(36864, 64, 3, 'conv64_3'); const conv128_down = extractResidualLayerParams(147456, 128, 3, 'conv128_down', true); const conv128_1 = extractResidualLayerParams(147456, 128, 3, 'conv128_1'); const conv128_2 = extractResidualLayerParams(147456, 128, 3, 'conv128_2'); const conv256_down = extractResidualLayerParams(589824, 256, 3, 'conv256_down', true); const conv256_1 = extractResidualLayerParams(589824, 256, 3, 'conv256_1'); const conv256_2 = extractResidualLayerParams(589824, 256, 3, 'conv256_2'); const conv256_down_out = extractResidualLayerParams(589824, 256, 3, 'conv256_down_out'); const fc = tf.tidy( () => tf.transpose(tf.tensor2d(extractWeights(256 * 128), [128, 256]), [1, 0]), ); paramMappings.push({ paramPath: 'fc' }); if (getRemainingWeights().length !== 0) { throw new Error(`weights remaing after extract: ${getRemainingWeights().length}`); } const params = { conv32_down, conv32_1, conv32_2, conv32_3, conv64_down, conv64_1, conv64_2, conv64_3, conv128_down, conv128_1, conv128_2, conv256_down, conv256_1, conv256_2, conv256_down_out, fc, }; return { params, paramMappings }; }