@vladmandic/face-api
Version:
FaceAPI: AI-powered Face Detection & Rotation Tracking, Face Description & Recognition, Age & Gender & Emotion Prediction for Browser and NodeJS using TensorFlow/JS
44 lines (37 loc) • 1.53 kB
text/typescript
import * as tf from '../../dist/tfjs.esm';
import { ExtractWeightsFunction, ParamMapping, SeparableConvParams } from './types';
/**
 * Creates a reader for one separable-convolution layer stored in a flat weight stream.
 *
 * Each call of the returned reader pulls three tensors, in this fixed order:
 *   1. depthwise filter, shape [3, 3, channelsIn, 1]
 *   2. pointwise filter, shape [1, 1, channelsIn, channelsOut]
 *   3. bias,             shape [channelsOut]
 * and appends one `{ paramPath }` entry per tensor (under `mappedPrefix`) to `paramMappings`.
 *
 * @param extractWeights - called with the element count of each tensor. Presumably it consumes
 *   that many values from a shared flat buffer, so the read order above must not change
 *   (NOTE(review): confirm against the ExtractWeightsFunction implementation).
 * @param paramMappings - array that is mutated (appended to) on every call of the returned reader.
 * @returns a reader `(channelsIn, channelsOut, mappedPrefix) => SeparableConvParams`.
 */
export function extractSeparableConvParamsFactory(
  extractWeights: ExtractWeightsFunction,
  paramMappings: ParamMapping[],
) {
  const readSeparableConv = (channelsIn: number, channelsOut: number, mappedPrefix: string): SeparableConvParams => {
    // 3x3 kernel per input channel; the trailing dimension of the shape is 1.
    const depthwiseCount = 3 * 3 * channelsIn;
    const depthwiseFilter = tf.tensor4d(extractWeights(depthwiseCount), [3, 3, channelsIn, 1]);

    // 1x1 kernel mixing channelsIn -> channelsOut.
    const pointwiseCount = channelsIn * channelsOut;
    const pointwiseFilter = tf.tensor4d(extractWeights(pointwiseCount), [1, 1, channelsIn, channelsOut]);

    const bias = tf.tensor1d(extractWeights(channelsOut));

    // Record one mapping per tensor, in the same order the tensors were read.
    paramMappings.push({ paramPath: `${mappedPrefix}/depthwise_filter` });
    paramMappings.push({ paramPath: `${mappedPrefix}/pointwise_filter` });
    paramMappings.push({ paramPath: `${mappedPrefix}/bias` });

    return new SeparableConvParams(depthwiseFilter, pointwiseFilter, bias);
  };
  return readSeparableConv;
}
/**
 * Creates a loader that resolves one separable-convolution layer's tensors by path.
 *
 * @param extractWeightEntry - looks up a weight by its full path and is passed the expected
 *   rank (4 for both filters, 1 for the bias). Its behavior on a missing entry or a rank
 *   mismatch is not visible here (NOTE(review): check its implementation).
 * @returns a loader `(prefix) => SeparableConvParams` that looks up
 *   `<prefix>/depthwise_filter`, `<prefix>/pointwise_filter` and `<prefix>/bias`, in that order.
 */
export function loadSeparableConvParamsFactory(
  // eslint-disable-next-line no-unused-vars
  extractWeightEntry: <T>(originalPath: string, paramRank: number) => T,
) {
  return function loadSeparableConv(prefix: string): SeparableConvParams {
    const filterRank = 4;
    const biasRank = 1;

    const depthwiseFilter = extractWeightEntry<tf.Tensor4D>(`${prefix}/depthwise_filter`, filterRank);
    const pointwiseFilter = extractWeightEntry<tf.Tensor4D>(`${prefix}/pointwise_filter`, filterRank);
    const bias = extractWeightEntry<tf.Tensor1D>(`${prefix}/bias`, biasRank);

    return new SeparableConvParams(depthwiseFilter, pointwiseFilter, bias);
  };
}