UNPKG

@vladmandic/face-api

Version:

JavaScript module for Face Detection and Face Recognition Using Tensorflow/JS

88 lines (72 loc) 3.27 kB
import * as tf from '@tensorflow/tfjs/dist/tf.es2017.js';
import { ConvParams } from '../common';
import { disposeUnusedWeightTensors } from '../common/disposeUnusedWeightTensors';
import { loadSeparableConvParamsFactory } from '../common/extractSeparableConvParamsFactory';
import { extractWeightEntryFactory } from '../common/extractWeightEntryFactory';
import { ParamMapping } from '../common/types';
import { TinyYolov2Config } from './config';
import { BatchNorm, ConvWithBatchNorm, TinyYolov2NetParams } from './types';

/**
 * Creates the per-layer extractor functions used to pull parameters out of
 * `weightMap`.
 *
 * All extractors share a single weight-entry reader obtained from
 * `extractWeightEntryFactory(weightMap, paramMappings)`.
 * NOTE(review): that reader presumably records each looked-up path in
 * `paramMappings` - confirm against `extractWeightEntryFactory`.
 *
 * @param weightMap Source of the named weights; passed through untouched.
 * @param paramMappings Array handed to the weight-entry reader factory.
 * @returns The extractors for plain conv layers, conv + batch-norm layers and
 * separable conv layers.
 */
function extractorsFactory(weightMap: any, paramMappings: ParamMapping[]) {
  const readEntry = extractWeightEntryFactory(weightMap, paramMappings)

  // The numeric second argument (1 or 4) is forwarded as-is to the reader;
  // presumably the expected tensor rank - TODO confirm.
  // Property order below fixes the lookup order: `sub` before `truediv`.
  const extractBatchNormParams = (prefix: string): BatchNorm => ({
    sub: readEntry<tf.Tensor1D>(`${prefix}/sub`, 1),
    truediv: readEntry<tf.Tensor1D>(`${prefix}/truediv`, 1),
  })

  // Lookup order: `filters` before `bias`.
  const extractConvParams = (prefix: string): ConvParams => ({
    filters: readEntry<tf.Tensor4D>(`${prefix}/filters`, 4),
    bias: readEntry<tf.Tensor1D>(`${prefix}/bias`, 1),
  })

  // A conv + batch-norm layer keeps its parts under `<prefix>/conv` and
  // `<prefix>/bn`; the conv part is read first.
  const extractConvWithBatchNormParams = (prefix: string): ConvWithBatchNorm => ({
    conv: extractConvParams(`${prefix}/conv`),
    bn: extractBatchNormParams(`${prefix}/bn`),
  })

  const extractSeparableConvParams = loadSeparableConvParamsFactory(readEntry)

  return {
    extractConvParams,
    extractConvWithBatchNormParams,
    extractSeparableConvParams,
  }
}

/**
 * Assembles the Tiny YOLO v2 network parameters from a named tensor map.
 *
 * Two layouts are supported, selected by `config.withSeparableConvs`:
 * - falsy: `conv0` .. `conv7` are conv + batch-norm layers;
 * - truthy: `conv1` .. `conv5` are separable convs, `conv0` is a plain conv
 *   when `config.isFirstLayerConv2d` is truthy (separable otherwise), and
 *   `conv6` / `conv7` are only extracted when the filter count exceeds 7 / 8
 *   respectively (they are `undefined` otherwise).
 * In both layouts `conv8` is a plain conv layer.
 *
 * After extraction, `weightMap` and the collected mappings are passed to
 * `disposeUnusedWeightTensors`.
 *
 * @param weightMap Named tensors to read the layer weights from.
 * @param config Network configuration selecting the layer layout.
 * @returns The extracted `params` together with the `paramMappings` array
 * that was shared with the extractors.
 */
export function extractParamsFromWeigthMap(
  weightMap: tf.NamedTensorMap,
  config: TinyYolov2Config
): { params: TinyYolov2NetParams, paramMappings: ParamMapping[] } {
  const paramMappings: ParamMapping[] = []

  const {
    extractConvParams: plainConv,
    extractConvWithBatchNormParams: convWithBatchNorm,
    extractSeparableConvParams: separableConv,
  } = extractorsFactory(weightMap, paramMappings)

  let params: TinyYolov2NetParams

  if (!config.withSeparableConvs) {
    params = {
      conv0: convWithBatchNorm('conv0'),
      conv1: convWithBatchNorm('conv1'),
      conv2: convWithBatchNorm('conv2'),
      conv3: convWithBatchNorm('conv3'),
      conv4: convWithBatchNorm('conv4'),
      conv5: convWithBatchNorm('conv5'),
      conv6: convWithBatchNorm('conv6'),
      conv7: convWithBatchNorm('conv7'),
      conv8: plainConv('conv8'),
    }
  } else {
    // Falls back to 9 when `filterSizes` is missing or has length 0.
    const filterCount = config.filterSizes ? config.filterSizes.length || 9 : 9

    params = {
      conv0: config.isFirstLayerConv2d ? plainConv('conv0') : separableConv('conv0'),
      conv1: separableConv('conv1'),
      conv2: separableConv('conv2'),
      conv3: separableConv('conv3'),
      conv4: separableConv('conv4'),
      conv5: separableConv('conv5'),
      // The two deepest separable layers exist only in larger configurations.
      conv6: filterCount > 7 ? separableConv('conv6') : undefined,
      conv7: filterCount > 8 ? separableConv('conv7') : undefined,
      conv8: plainConv('conv8'),
    }
  }

  disposeUnusedWeightTensors(weightMap, paramMappings)

  return { params, paramMappings }
}