@vladmandic/face-api
Version:
FaceAPI: AI-powered Face Detection & Rotation Tracking, Face Description & Recognition, Age & Gender & Emotion Prediction for Browser and NodeJS using TensorFlow/JS
149 lines (122 loc) • 4.89 kB
text/typescript
import * as tf from '../../dist/tfjs.esm';
import { ConvParams, extractWeightsFactory, ExtractWeightsFunction, ParamMapping } from '../common/index';
import { isFloat } from '../utils/index';
import { ConvLayerParams, NetParams, ResidualLayerParams, ScaleLayerParams } from './types';
/**
 * Builds the layer-level weight readers used by `extractParams`.
 *
 * Every reader pulls the next values off the shared weight stream through
 * `extractWeights`. The order of those calls therefore decides which values
 * each tensor receives. Each reader also appends one `{ paramPath }` entry to
 * `paramMappings` per tensor it creates, in creation order.
 *
 * @param extractWeights - pops the next `n` values from the flat weight buffer
 * @param paramMappings - array that receives the path of every tensor produced
 * @returns `extractConvLayerParams` and `extractResidualLayerParams` readers
 */
function extractorsFactory(extractWeights: ExtractWeightsFunction, paramMappings: ParamMapping[]) {
  /** Appends one `{ paramPath }` mapping per path, preserving argument order. */
  const recordPaths = (...paths: string[]): void => {
    for (const paramPath of paths) {
      paramMappings.push({ paramPath });
    }
  };

  /**
   * Reads `numFilterValues` weights stored as [numFilters, depth, k, k] and
   * returns them transposed to [k, k, depth, numFilters]. The depth is inferred
   * from the number of values read.
   */
  const readFilters = (numFilterValues: number, numFilters: number, filterSize: number): tf.Tensor4D => {
    const values = extractWeights(numFilterValues);
    const depth = values.length / (numFilters * filterSize * filterSize);
    if (isFloat(depth)) {
      throw new Error(`depth has to be an integer: ${depth}, weights.length: ${values.length}, numFilters: ${numFilters}, filterSize: ${filterSize}`);
    }
    // tidy frees the untransposed intermediate and keeps only the result
    return tf.tidy(() => {
      const stored = tf.tensor4d(values, [numFilters, depth, filterSize, filterSize]);
      return tf.transpose(stored, [2, 3, 1, 0]);
    });
  };

  /** Reads a conv kernel, then one bias value per filter. */
  const readConv = (
    numFilterValues: number,
    numFilters: number,
    filterSize: number,
    mappedPrefix: string,
  ): ConvParams => {
    const filters = readFilters(numFilterValues, numFilters, filterSize);
    const bias = tf.tensor1d(extractWeights(numFilters));
    recordPaths(`${mappedPrefix}/filters`, `${mappedPrefix}/bias`);
    return { filters, bias };
  };

  /** Reads a scale layer: `numWeights` weights followed by `numWeights` biases. */
  const readScale = (numWeights: number, mappedPrefix: string): ScaleLayerParams => {
    const weights = tf.tensor1d(extractWeights(numWeights));
    const biases = tf.tensor1d(extractWeights(numWeights));
    recordPaths(`${mappedPrefix}/weights`, `${mappedPrefix}/biases`);
    return { weights, biases };
  };

  /** Reads a conv (`<prefix>/conv`) followed by its scale layer (`<prefix>/scale`). */
  const extractConvLayerParams = (
    numFilterValues: number,
    numFilters: number,
    filterSize: number,
    mappedPrefix: string,
  ): ConvLayerParams => ({
    conv: readConv(numFilterValues, numFilters, filterSize, `${mappedPrefix}/conv`),
    scale: readScale(numFilters, `${mappedPrefix}/scale`),
  });

  /**
   * Reads two consecutive conv layers (`<prefix>/conv1`, `<prefix>/conv2`).
   * When `isDown` is set, the first conv reads only half of `numFilterValues`,
   * which halves its inferred filter depth.
   */
  const extractResidualLayerParams = (
    numFilterValues: number,
    numFilters: number,
    filterSize: number,
    mappedPrefix: string,
    isDown = false,
  ): ResidualLayerParams => {
    const firstConvValues = isDown ? 0.5 * numFilterValues : numFilterValues;
    return {
      conv1: extractConvLayerParams(firstConvValues, numFilters, filterSize, `${mappedPrefix}/conv1`),
      conv2: extractConvLayerParams(numFilterValues, numFilters, filterSize, `${mappedPrefix}/conv2`),
    };
  };

  return { extractConvLayerParams, extractResidualLayerParams };
}
/**
 * Deserializes a flat weight buffer into the face recognition network's parameter tensors.
 *
 * Weights are consumed strictly in the order of the layers below. Within each
 * layer the order is conv filters, conv bias, scale weights, then scale biases.
 *
 * @param weights - flat weight buffer; it must be consumed exactly, with nothing left over
 * @returns the parameter tensors, plus one path mapping per tensor in extraction order
 * @throws Error if weights remain after extraction, if a layer's filter depth is
 *   not an integer, or if any other extraction step throws. In every failure case
 *   the tensors allocated so far are disposed rather than leaked.
 */
export function extractParams(weights: Float32Array): { params: NetParams, paramMappings: ParamMapping[] } {
  const { extractWeights, getRemainingWeights } = extractWeightsFactory(weights);
  const paramMappings: ParamMapping[] = [];
  const { extractConvLayerParams, extractResidualLayerParams } = extractorsFactory(extractWeights, paramMappings);

  // The whole extraction runs inside tf.tidy. On success the returned tensors
  // are kept. If any step throws, tidy disposes every tensor created so far, so
  // a malformed weight file does not leak backend memory. Intermediates that are
  // not returned, such as the untransposed fc matrix, are released too.
  const params = tf.tidy(() => {
    const conv32_down = extractConvLayerParams(4704, 32, 7, 'conv32_down');
    const conv32_1 = extractResidualLayerParams(9216, 32, 3, 'conv32_1');
    const conv32_2 = extractResidualLayerParams(9216, 32, 3, 'conv32_2');
    const conv32_3 = extractResidualLayerParams(9216, 32, 3, 'conv32_3');
    const conv64_down = extractResidualLayerParams(36864, 64, 3, 'conv64_down', true);
    const conv64_1 = extractResidualLayerParams(36864, 64, 3, 'conv64_1');
    const conv64_2 = extractResidualLayerParams(36864, 64, 3, 'conv64_2');
    const conv64_3 = extractResidualLayerParams(36864, 64, 3, 'conv64_3');
    const conv128_down = extractResidualLayerParams(147456, 128, 3, 'conv128_down', true);
    const conv128_1 = extractResidualLayerParams(147456, 128, 3, 'conv128_1');
    const conv128_2 = extractResidualLayerParams(147456, 128, 3, 'conv128_2');
    const conv256_down = extractResidualLayerParams(589824, 256, 3, 'conv256_down', true);
    const conv256_1 = extractResidualLayerParams(589824, 256, 3, 'conv256_1');
    const conv256_2 = extractResidualLayerParams(589824, 256, 3, 'conv256_2');
    const conv256_down_out = extractResidualLayerParams(589824, 256, 3, 'conv256_down_out');
    // The fc weights are stored as [128, 256] and transposed to [256, 128].
    const fc = tf.transpose(tf.tensor2d(extractWeights(256 * 128), [128, 256]), [1, 0]);
    paramMappings.push({ paramPath: 'fc' });

    const remaining = getRemainingWeights().length;
    if (remaining !== 0) {
      throw new Error(`weights remaining after extract: ${remaining}`);
    }

    return {
      conv32_down,
      conv32_1,
      conv32_2,
      conv32_3,
      conv64_down,
      conv64_1,
      conv64_2,
      conv64_3,
      conv128_down,
      conv128_1,
      conv128_2,
      conv256_down,
      conv256_1,
      conv256_2,
      conv256_down_out,
      fc,
    };
  });

  return { params, paramMappings };
}