UNPKG

@vladmandic/face-api

Version:

FaceAPI: AI-powered Face Detection & Rotation Tracking, Face Description & Recognition, Age & Gender & Emotion Prediction for Browser and NodeJS using TensorFlow/JS

97 lines (77 loc) 3.1 kB
import * as tf from '../../dist/tfjs.esm'; import { disposeUnusedWeightTensors, extractWeightEntryFactory, ParamMapping } from '../common/index'; import { isTensor2D } from '../utils/index'; import { ConvLayerParams, NetParams, ResidualLayerParams, ScaleLayerParams } from './types'; function extractorsFactory(weightMap: any, paramMappings: ParamMapping[]) { const extractWeightEntry = extractWeightEntryFactory(weightMap, paramMappings); function extractScaleLayerParams(prefix: string): ScaleLayerParams { const weights = extractWeightEntry(`${prefix}/scale/weights`, 1); const biases = extractWeightEntry(`${prefix}/scale/biases`, 1); return { weights, biases }; } function extractConvLayerParams(prefix: string): ConvLayerParams { const filters = extractWeightEntry(`${prefix}/conv/filters`, 4); const bias = extractWeightEntry(`${prefix}/conv/bias`, 1); const scale = extractScaleLayerParams(prefix); return { conv: { filters, bias }, scale }; } function extractResidualLayerParams(prefix: string): ResidualLayerParams { return { conv1: extractConvLayerParams(`${prefix}/conv1`), conv2: extractConvLayerParams(`${prefix}/conv2`), }; } return { extractConvLayerParams, extractResidualLayerParams, }; } export function extractParamsFromWeightMap( weightMap: tf.NamedTensorMap, ): { params: NetParams, paramMappings: ParamMapping[] } { const paramMappings: ParamMapping[] = []; const { extractConvLayerParams, extractResidualLayerParams, } = extractorsFactory(weightMap, paramMappings); const conv32_down = extractConvLayerParams('conv32_down'); const conv32_1 = extractResidualLayerParams('conv32_1'); const conv32_2 = extractResidualLayerParams('conv32_2'); const conv32_3 = extractResidualLayerParams('conv32_3'); const conv64_down = extractResidualLayerParams('conv64_down'); const conv64_1 = extractResidualLayerParams('conv64_1'); const conv64_2 = extractResidualLayerParams('conv64_2'); const conv64_3 = extractResidualLayerParams('conv64_3'); const conv128_down = extractResidualLayerParams('conv128_down'); const conv128_1 = extractResidualLayerParams('conv128_1'); const conv128_2 = extractResidualLayerParams('conv128_2'); const conv256_down = extractResidualLayerParams('conv256_down'); const conv256_1 = extractResidualLayerParams('conv256_1'); const conv256_2 = extractResidualLayerParams('conv256_2'); const conv256_down_out = extractResidualLayerParams('conv256_down_out'); const { fc } = weightMap; paramMappings.push({ originalPath: 'fc', paramPath: 'fc' }); if (!isTensor2D(fc)) { throw new Error(`expected weightMap[fc] to be a Tensor2D, instead have ${fc}`); } const params = { conv32_down, conv32_1, conv32_2, conv32_3, conv64_down, conv64_1, conv64_2, conv64_3, conv128_down, conv128_1, conv128_2, conv256_down, conv256_1, conv256_2, conv256_down_out, fc, }; disposeUnusedWeightTensors(weightMap, paramMappings); return { params, paramMappings }; }