UNPKG

@vladmandic/face-api

Version:

FaceAPI: AI-powered Face Detection & Rotation Tracking, Face Description & Recognition, Age & Gender & Emotion Prediction for Browser and NodeJS using TensorFlow/JS

44 lines (37 loc) 1.53 kB
import * as tf from '../../dist/tfjs.esm';
import { ExtractWeightsFunction, ParamMapping, SeparableConvParams } from './types';

/**
 * Creates a reader that builds one separable-convolution layer from successive
 * `extractWeights` calls.
 *
 * Each invocation of the returned function performs three reads, strictly in this order:
 *   1. depthwise filter, reshaped to `[3, 3, channelsIn, 1]`
 *   2. pointwise filter, reshaped to `[1, 1, channelsIn, channelsOut]`
 *   3. bias, of length `channelsOut`
 * It then records one entry per tensor in `paramMappings`, under `mappedPrefix`.
 *
 * NOTE(review): this presumes `extractWeights` consumes values from a running offset
 * into a flat weight buffer, so the read order above must match the serialized
 * layout. Confirm against the `ExtractWeightsFunction` implementation.
 *
 * @param extractWeights - called with the number of values to read for each tensor.
 * @param paramMappings - mutated: three `{ paramPath }` entries are appended per call.
 * @returns a function `(channelsIn, channelsOut, mappedPrefix) => SeparableConvParams`.
 */
export function extractSeparableConvParamsFactory(
  extractWeights: ExtractWeightsFunction,
  paramMappings: ParamMapping[],
) {
  return (channelsIn: number, channelsOut: number, mappedPrefix: string): SeparableConvParams => {
    // Read order matters: each extractWeights call continues where the previous one stopped.
    const depthwiseValues = extractWeights(3 * 3 * channelsIn);
    const depthwiseFilter = tf.tensor4d(depthwiseValues, [3, 3, channelsIn, 1]);

    const pointwiseValues = extractWeights(channelsIn * channelsOut);
    const pointwiseFilter = tf.tensor4d(pointwiseValues, [1, 1, channelsIn, channelsOut]);

    const biasValues = extractWeights(channelsOut);
    const biasVector = tf.tensor1d(biasValues);

    // Register the named paths in the same order the tensors were read.
    const depthwiseMapping = { paramPath: `${mappedPrefix}/depthwise_filter` };
    const pointwiseMapping = { paramPath: `${mappedPrefix}/pointwise_filter` };
    const biasMapping = { paramPath: `${mappedPrefix}/bias` };
    paramMappings.push(depthwiseMapping, pointwiseMapping, biasMapping);

    return new SeparableConvParams(depthwiseFilter, pointwiseFilter, biasVector);
  };
}

/**
 * Creates a reader that assembles one separable-convolution layer from named weight entries.
 *
 * For a given `prefix`, the returned function looks up, in this order:
 *   - `${prefix}/depthwise_filter` (requested with rank 4)
 *   - `${prefix}/pointwise_filter` (requested with rank 4)
 *   - `${prefix}/bias` (requested with rank 1)
 * It wraps the three results in a `SeparableConvParams`.
 *
 * @param extractWeightEntry - resolves a weight path and expected rank to a tensor;
 *   presumably it throws or validates on a missing or mis-ranked entry (not shown here; verify).
 * @returns a function `(prefix) => SeparableConvParams`.
 */
export function loadSeparableConvParamsFactory(
  // eslint-disable-next-line no-unused-vars
  extractWeightEntry: <T>(originalPath: string, paramRank: number) => T,
) {
  // Constructor arguments are evaluated left to right, so the lookup order
  // is depthwise, pointwise, then bias.
  return (prefix: string): SeparableConvParams => new SeparableConvParams(
    extractWeightEntry<tf.Tensor4D>(`${prefix}/depthwise_filter`, 4),
    extractWeightEntry<tf.Tensor4D>(`${prefix}/pointwise_filter`, 4),
    extractWeightEntry<tf.Tensor1D>(`${prefix}/bias`, 1),
  );
}