/*
 * UNPKG viewer header (scrape residue, not part of the module source):
 * @tensorflow-models/body-pix
 * Version: (not captured in this extract)
 * Pretrained BodyPix model in TensorFlow.js
 * 993 lines (972 loc) 135 kB
 */
/** * @license * Copyright 2023 Google LLC. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ import * as tfconv from '@tensorflow/tfjs-converter'; import * as tf from '@tensorflow/tfjs-core'; import { getBackend } from '@tensorflow/tfjs-core'; /*! ***************************************************************************** Copyright (c) Microsoft Corporation. Permission to use, copy, modify, and/or distribute this software for any purpose with or without fee is hereby granted. THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 
***************************************************************************** */ /* global Reflect, Promise */
// ---------------------------------------------------------------------------
// TypeScript runtime (tslib) helpers, inlined by the bundler. These are the
// standard ES5 downlevel shims for `extends`, object spread, and async/await.
// ---------------------------------------------------------------------------
// Picks the fastest available mechanism for wiring up a prototype chain
// (Object.setPrototypeOf, writable __proto__, or manual own-property copy) and
// caches the choice by overwriting itself on first call.
var extendStatics = function(d, b) { extendStatics = Object.setPrototypeOf || ({ __proto__: [] } instanceof Array && function (d, b) { d.__proto__ = b; }) || function (d, b) { for (var p in b) if (b.hasOwnProperty(p)) d[p] = b[p]; }; return extendStatics(d, b); };
// ES5 equivalent of `class d extends b` (used below by MobileNet extending
// BaseModel): copies statics, then links d.prototype to b.prototype.
function __extends(d, b) { extendStatics(d, b); function __() { this.constructor = d; } d.prototype = b === null ? Object.create(b) : (__.prototype = b.prototype, new __()); }
// ES5 equivalent of Object.assign / object spread; caches the implementation
// by overwriting itself on first call.
var __assign = function() { __assign = Object.assign || function __assign(t) { for (var s, i = 1, n = arguments.length; i < n; i++) { s = arguments[i]; for (var p in s) if (Object.prototype.hasOwnProperty.call(s, p)) t[p] = s[p]; } return t; }; return __assign.apply(this, arguments); };
// Drives a downleveled async function: runs the generator produced by
// __generator, adopting each yielded value into the Promise type P and
// resolving/rejecting the outer promise when the generator completes.
function __awaiter(thisArg, _arguments, P, generator) { function adopt(value) { return value instanceof P ? value : new P(function (resolve) { resolve(value); }); } return new (P || (P = Promise))(function (resolve, reject) { function fulfilled(value) { try { step(generator.next(value)); } catch (e) { reject(e); } } function rejected(value) { try { step(generator["throw"](value)); } catch (e) { reject(e); } } function step(result) { result.done ? resolve(result.value) : adopt(result.value).then(fulfilled, rejected); } step((generator = generator.apply(thisArg, _arguments || [])).next()); }); }
// ES5 state machine emulating a generator function. `_.label` is the resume
// point inside `body`, `op` is an [opcode, value] pair (e.g. case 4 emits a
// yield, case 7 ends a finally block), and `_.trys` tracks try/catch/finally
// regions. `sent()` returns the last resume value, rethrowing it when the low
// bit of the last opcode marks it as a thrown value.
function __generator(thisArg, body) { var _ = { label: 0, sent: function() { if (t[0] & 1) throw t[1]; return t[1]; }, trys: [], ops: [] }, f, y, t, g; return g = { next: verb(0), "throw": verb(1), "return": verb(2) }, typeof Symbol === "function" && (g[Symbol.iterator] = function() { return this; }), g; function verb(n) { return function (v) { return step([n, v]); }; } function step(op) { if (f) throw new TypeError("Generator is already executing."); while (_) try { if (f = 1, y && (t = op[0] & 2 ? y["return"] : op[0] ? y["throw"] || ((t = y["return"]) && t.call(y), 0) : y.next) && !(t = t.call(y, op[1])).done) return t; if (y = 0, t) op = [op[0] & 2, t.value]; switch (op[0]) { case 0: case 1: t = op; break; case 4: _.label++; return { value: op[1], done: false }; case 5: _.label++; y = op[1]; op = [0]; continue; case 7: op = _.ops.pop(); _.trys.pop(); continue; default: if (!(t = _.trys, t = t.length > 0 && t[t.length - 1]) && (op[0] === 6 || op[0] === 2)) { _ = 0; continue; } if (op[0] === 3 && (!t || (op[1] > t[0] && op[1] < t[3]))) { _.label = op[1]; break; } if (op[0] === 6 && _.label < t[1]) { _.label = t[1]; t = op; break; } if (t && _.label < t[2]) { _.label = t[2]; _.ops.push(op); break; } if (t[2]) _.ops.pop(); _.trys.pop(); continue; } op = body.call(thisArg, _); } catch (e) { op = [6, e]; y = 0; } finally { f = t = 0; } if (op[0] & 5) throw op[1]; return { value: op[0] ? op[1] : void 0, done: true }; } }
/** * @license * Copyright 2019 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */
/** * Takes the sigmoid of the part heatmap output and generates a 2d one-hot * tensor with ones where the part's score has the maximum value. 
* * @param partHeatmapScores */
// Flattens the [h, w, numParts] part scores into a [h*w, numParts] one-hot
// matrix marking, for every pixel, the part with the highest score.
function toFlattenedOneHotPartMap(partHeatmapScores) { var numParts = partHeatmapScores.shape[2]; var partMapLocations = tf.argMax(partHeatmapScores, 2); var partMapFlattened = tf.reshape(partMapLocations, [-1]); return tf.oneHot(partMapFlattened, numParts); }
// Zeroes out `image` wherever the binary `mask` is 0 (element-wise multiply).
function clipByMask2d(image, mask) { return tf.mul(image, mask); }
/** * Takes the sigmoid of the segmentation output, and generates a segmentation * mask with a 1 or 0 at each pixel where there is a person or not a person. The * segmentation threshold determines the threshold of a score for a pixel for it * to be considered part of a person. * @param segmentScores A 3d-tensor of the sigmoid of the segmentation output. * @param segmentationThreshold The minimum that segmentation values must have * to be considered part of the person. Affects the generation of the * segmentation mask and the clipping of the colored part image. * * @returns A segmentation mask with a 1 or 0 at each pixel where there is a * person or not a person. */
// NOTE(review): the parameter is named `threshold` here but documented as
// `segmentationThreshold` above — they are the same value.
function toMaskTensor(segmentScores, threshold) { return tf.tidy(function () { return tf.cast(tf.greater(segmentScores, tf.scalar(threshold)), 'int32'); }); }
/** * Takes the sigmoid of the person and part map output, and returns a 2d tensor * of an image with the corresponding value at each pixel corresponding to the * part with the highest value. These part ids are clipped by the segmentation * mask. Wherever the a pixel is clipped by the segmentation mask, its value * will set to -1, indicating that there is no part in that pixel. * @param segmentScores A 3d-tensor of the sigmoid of the segmentation output. * @param partHeatmapScores A 3d-tensor of the sigmoid of the part heatmap * output. The third dimension corresponds to the part. * * @returns A 2d tensor of an image with the corresponding value at each pixel * corresponding to the part with the highest value. These part ids are clipped * by the segmentation mask. It will have values of -1 for pixels that are * outside of the body and do not have a corresponding part. */
function decodePartSegmentation(segmentationMask, partHeatmapScores) { var _a = partHeatmapScores.shape, partMapHeight = _a[0], partMapWidth = _a[1], numParts = _a[2]; return tf.tidy(function () { var flattenedMap = toFlattenedOneHotPartMap(partHeatmapScores); var partNumbers = tf.expandDims(tf.range(0, numParts, 1, 'int32'), 1); 
// one-hot x [0..numParts-1] column vector selects each pixel's part id.
var partMapFlattened = tf.cast(tf.matMul(flattenedMap, partNumbers), 'int32'); var partMap = tf.reshape(partMapFlattened, [partMapHeight, partMapWidth]); 
// Shift ids up by 1 so that masking with 0 can later be mapped back to -1.
var partMapShiftedUpForClipping = tf.add(partMap, tf.scalar(1, 'int32')); return tf.sub(clipByMask2d(partMapShiftedUpForClipping, segmentationMask), tf.scalar(1, 'int32')); }); }
// Same as decodePartSegmentation but without clipping by a segmentation mask:
// returns the highest-scoring part id at every pixel.
function decodeOnlyPartSegmentation(partHeatmapScores) { var _a = partHeatmapScores.shape, partMapHeight = _a[0], partMapWidth = _a[1], numParts = _a[2]; return tf.tidy(function () { var flattenedMap = toFlattenedOneHotPartMap(partHeatmapScores); var partNumbers = tf.expandDims(tf.range(0, numParts, 1, 'int32'), 1); var partMapFlattened = tf.cast(tf.matMul(flattenedMap, partNumbers), 'int32'); return tf.reshape(partMapFlattened, [partMapHeight, partMapWidth]); }); }
/** * @license * Copyright 2019 Google Inc. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ /** * BodyPix supports using various convolution neural network models * (e.g. ResNet and MobileNetV1) as its underlying base model. * The following BaseModel interface defines a unified interface for * creating such BodyPix base models. Currently both MobileNet (in * ./mobilenet.ts) and ResNet (in ./resnet.ts) implements the BaseModel * interface. New base models that conform to the BaseModel interface can be * added to BodyPix. */
var BaseModel = /** @class */ (function () {
// Stores the converted graph model and its output stride, asserting that the
// model accepts dynamic (-1) spatial input dimensions.
function BaseModel(model, outputStride) { this.model = model; this.outputStride = outputStride; var inputShape = this.model.inputs[0].shape; tf.util.assert((inputShape[1] === -1) && (inputShape[2] === -1), function () { return "Input shape [".concat(inputShape[1], ", ").concat(inputShape[2], "] ") + "must both be equal to or -1"; }); }
/** * Predicts intermediate Tensor representations. * * @param input The input RGB image of the base model. * A Tensor of shape: [`inputResolution`, `inputResolution`, 3]. * * @return A dictionary of base model's intermediate predictions. * The returned dictionary should contains the following elements: * - heatmapScores: A Tensor3D that represents the keypoint heatmap scores. * - offsets: A Tensor3D that represents the offsets. * - displacementFwd: A Tensor3D that represents the forward displacement. * - displacementBwd: A Tensor3D that represents the backward displacement. * - segmentation: A Tensor3D that represents the segmentation of all * people. * - longOffsets: A Tensor3D that represents the long offsets used for * instance grouping. * - partHeatmaps: A Tensor3D that represents the body part segmentation. */
// Preprocesses, batches, and runs the model, then squeezes the batch dim and
// lets the subclass's nameOutputResults map positional outputs to names.
// Only the heatmap is passed through a sigmoid here.
BaseModel.prototype.predict = function (input) { var _this = this; return tf.tidy(function () { var asFloat = _this.preprocessInput(tf.cast(input, 'float32')); var asBatch = tf.expandDims(asFloat, 0); var results = _this.model.predict(asBatch); var results3d = results.map(function (y) { return tf.squeeze(y, [0]); }); var namedResults = _this.nameOutputResults(results3d); return { heatmapScores: tf.sigmoid(namedResults.heatmap), offsets: namedResults.offsets, displacementFwd: namedResults.displacementFwd, displacementBwd: namedResults.displacementBwd, segmentation: namedResults.segmentation, partHeatmaps: namedResults.partHeatmaps, longOffsets: namedResults.longOffsets, partOffsets: namedResults.partOffsets }; }); };
/** * Releases the CPU and GPU memory allocated by the model. */
BaseModel.prototype.dispose = function () { this.model.dispose(); };
return BaseModel; }());
/** * @license * Copyright 2019 Google Inc. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */
// MobileNetV1-backed BaseModel: scales raw pixels into [-1, 1] and names the
// model's raw outputs by their fixed positional order in the MobileNet graph.
var MobileNet = /** @class */ (function (_super) { __extends(MobileNet, _super); function MobileNet() { return _super !== null && _super.apply(this, arguments) || this; }
MobileNet.prototype.preprocessInput = function (input) {
// Normalize the pixels [0, 255] to be between [-1, 1].
return tf.tidy(function () { return tf.sub(tf.div(input, 127.5), 1.0); }); };
// Output tensor order below is specific to the MobileNet BodyPix graph.
MobileNet.prototype.nameOutputResults = function (results) { var offsets = results[0], segmentation = results[1], partHeatmaps = results[2], longOffsets = results[3], heatmap = results[4], displacementFwd = results[5], displacementBwd = results[6], partOffsets = results[7]; return { offsets: offsets, segmentation: segmentation, partHeatmaps: partHeatmaps, longOffsets: longOffsets, heatmap: heatmap, displacementFwd: displacementFwd, displacementBwd: displacementBwd, partOffsets: partOffsets }; };
return MobileNet; }(BaseModel));
/** * @license * Copyright 2019 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ var PART_NAMES = [ 'nose', 'leftEye', 'rightEye', 'leftEar', 'rightEar', 'leftShoulder', 'rightShoulder', 'leftElbow', 'rightElbow', 'leftWrist', 'rightWrist', 'leftHip', 'rightHip', 'leftKnee', 'rightKnee', 'leftAnkle', 'rightAnkle' ]; var NUM_KEYPOINTS = PART_NAMES.length; var PART_IDS = PART_NAMES.reduce(function (result, jointName, i) { result[jointName] = i; return result; }, {}); var CONNECTED_PART_NAMES = [ ['leftHip', 'leftShoulder'], ['leftElbow', 'leftShoulder'], ['leftElbow', 'leftWrist'], ['leftHip', 'leftKnee'], ['leftKnee', 'leftAnkle'], ['rightHip', 'rightShoulder'], ['rightElbow', 'rightShoulder'], ['rightElbow', 'rightWrist'], ['rightHip', 'rightKnee'], ['rightKnee', 'rightAnkle'], ['leftShoulder', 'rightShoulder'], ['leftHip', 'rightHip'] ]; /* * Define the skeleton. This defines the parent->child relationships of our * tree. Arbitrarily this defines the nose as the root of the tree, however * since we will infer the displacement for both parent->child and * child->parent, we can define the tree root as any node. */ var POSE_CHAIN = [ ['nose', 'leftEye'], ['leftEye', 'leftEar'], ['nose', 'rightEye'], ['rightEye', 'rightEar'], ['nose', 'leftShoulder'], ['leftShoulder', 'leftElbow'], ['leftElbow', 'leftWrist'], ['leftShoulder', 'leftHip'], ['leftHip', 'leftKnee'], ['leftKnee', 'leftAnkle'], ['nose', 'rightShoulder'], ['rightShoulder', 'rightElbow'], ['rightElbow', 'rightWrist'], ['rightShoulder', 'rightHip'], ['rightHip', 'rightKnee'], ['rightKnee', 'rightAnkle'] ]; CONNECTED_PART_NAMES.map(function (_a) { var jointNameA = _a[0], jointNameB = _a[1]; return ([PART_IDS[jointNameA], PART_IDS[jointNameB]]); }); /** * @license * Copyright 2019 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function getScale(_a, _b, padding) { var height = _a[0], width = _a[1]; var inputResolutionY = _b[0], inputResolutionX = _b[1]; var padT = padding.top, padB = padding.bottom, padL = padding.left, padR = padding.right; var scaleY = inputResolutionY / (padT + padB + height); var scaleX = inputResolutionX / (padL + padR + width); return [scaleX, scaleY]; } function getOffsetPoint(y, x, keypoint, offsets) { return { y: offsets.get(y, x, keypoint), x: offsets.get(y, x, keypoint + NUM_KEYPOINTS) }; } function getImageCoords(part, outputStride, offsets) { var heatmapY = part.heatmapY, heatmapX = part.heatmapX, keypoint = part.id; var _a = getOffsetPoint(heatmapY, heatmapX, keypoint, offsets), y = _a.y, x = _a.x; return { x: part.heatmapX * outputStride + x, y: part.heatmapY * outputStride + y }; } function clamp(a, min, max) { if (a < min) { return min; } if (a > max) { return max; } return a; } function squaredDistance(y1, x1, y2, x2) { var dy = y2 - y1; var dx = x2 - x1; return dy * dy + dx * dx; } function addVectors(a, b) { return { x: a.x + b.x, y: a.y + b.y }; } /** * @license * Copyright 2019 Google Inc. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */
// Mean squared distance between an embedding (per-keypoint refined positions)
// and a pose's keypoints, counting only keypoints scoring above minPartScore;
// Infinity when no keypoint qualifies.
function computeDistance(embedding, pose, minPartScore) { if (minPartScore === void 0) { minPartScore = 0.3; } var distance = 0.0; var numKpt = 0; for (var p = 0; p < embedding.length; p++) { if (pose.keypoints[p].score > minPartScore) { numKpt += 1; distance += Math.pow((embedding[p].x - pose.keypoints[p].position.x), 2) + Math.pow((embedding[p].y - pose.keypoints[p].position.y), 2); } } if (numKpt === 0) { distance = Infinity; } else { distance = distance / numKpt; } return distance; }
// Maps an image-space position into the strided model-output grid.
// NOTE(review): "Ouput" is a typo for "Output" kept for internal consistency.
function convertToPositionInOuput(position, _a, _b, stride) { var padT = _a[0], padL = _a[1]; var scaleX = _b[0], scaleY = _b[1]; var y = Math.round(((padT + position.y + 1.0) * scaleY - 1.0) / stride); var x = Math.round(((padL + position.x + 1.0) * scaleX - 1.0) / stride); return { x: x, y: y }; }
// Follows the long-offset field from `location` toward keypoint
// `keypointIndex`, iterating `refineSteps` times. `longOffsets` is the flat
// output buffer: for grid cell nn, the dy for keypoint k is at
// NUM_KEYPOINTS*(2*nn)+k and the dx at NUM_KEYPOINTS*(2*nn+1)+k.
function getEmbedding(location, keypointIndex, convertToPosition, outputResolutionX, longOffsets, refineSteps, _a) { var height = _a[0], width = _a[1]; var newLocation = convertToPosition(location); var nn = newLocation.y * outputResolutionX + newLocation.x; var dy = longOffsets[NUM_KEYPOINTS * (2 * nn) + keypointIndex]; var dx = longOffsets[NUM_KEYPOINTS * (2 * nn + 1) + keypointIndex]; var y = location.y + dy; var x = location.x + dx; for (var t = 0; t < refineSteps; t++) { y = Math.min(y, height - 1); x = Math.min(x, width - 1); var newPos = convertToPosition({ x: x, y: y }); var nn_1 = newPos.y * outputResolutionX + newPos.x; dy = longOffsets[NUM_KEYPOINTS * (2 * nn_1) + keypointIndex]; dx = longOffsets[NUM_KEYPOINTS * (2 * nn_1 + 1) + keypointIndex]; y = y + dy; x = x + dx; } return { x: x, y: y }; }
// Returns the index of the pose whose keypoints best match the refined
// embedding of `location` (smallest mean squared distance), or -1 when
// `poses` is empty.
function matchEmbeddingToInstance(location, longOffsets, poses, numKptForMatching, _a, _b, outputResolutionX, _c, stride, refineSteps) { var padT = _a[0], padL = _a[1]; var scaleX = _b[0], scaleY = _b[1]; var height = _c[0], width = _c[1]; var embed = []; var convertToPosition = function (pair) { return convertToPositionInOuput(pair, [padT, padL], [scaleX, scaleY], stride); }; for (var keypointsIndex = 0; keypointsIndex < numKptForMatching; keypointsIndex++) { var embedding = getEmbedding(location, keypointsIndex, convertToPosition, outputResolutionX, longOffsets, refineSteps, [height, width]); embed.push(embedding); } var kMin = -1; var kMinDist = Infinity; for (var k = 0; k < poses.length; k++) { var dist = computeDistance(embed, poses[k]); if (dist < kMinDist) { kMin = k; kMinDist = dist; } } return kMin; }
// Size of the strided model-output grid for a given input resolution.
function getOutputResolution(_a, stride) { var inputResolutionY = _a[0], inputResolutionX = _a[1]; var outputResolutionX = Math.round((inputResolutionX - 1.0) / stride + 1.0); var outputResolutionY = Math.round((inputResolutionY - 1.0) / stride + 1.0); return [outputResolutionX, outputResolutionY]; }
// CPU fallback for multi-person mask decoding: for every segmented pixel,
// assigns the pixel to the nearest pose instance. Returns one binary
// Uint8Array mask (length height*width) per pose in `posesAboveScore`.
function decodeMultipleMasksCPU(segmentation, longOffsets, posesAboveScore, height, width, stride, _a, padding, refineSteps, numKptForMatching) { var inHeight = _a[0], inWidth = _a[1]; if (numKptForMatching === void 0) { numKptForMatching = 5; } var dataArrays = posesAboveScore.map(function (x) { return new Uint8Array(height * width).fill(0); }); var padT = padding.top, padL = padding.left; var _b = getScale([height, width], [inHeight, inWidth], padding), scaleX = _b[0], scaleY = _b[1]; var outputResolutionX = getOutputResolution([inHeight, inWidth], stride)[0]; for (var i = 0; i < height; i += 1) { for (var j = 0; j < width; j += 1) { var n = i * width + j; var prob = segmentation[n]; if (prob === 1) { var kMin = matchEmbeddingToInstance({ x: j, y: i }, longOffsets, posesAboveScore, numKptForMatching, [padT, padL], [scaleX, scaleY], outputResolutionX, [height, width], stride, refineSteps); if (kMin >= 0) { dataArrays[kMin][n] = 1; } } } } return dataArrays; }
// Same as decodeMultipleMasksCPU, but writes the body-part id from
// `partSegmentaion` instead of 1; unassigned pixels stay -1 (Int32Array).
function decodeMultiplePartMasksCPU(segmentation, longOffsets, partSegmentaion, posesAboveScore, height, width, stride, _a, padding, refineSteps, numKptForMatching) { var inHeight = _a[0], inWidth = _a[1]; if (numKptForMatching === void 0) { numKptForMatching = 5; } var dataArrays = posesAboveScore.map(function (x) { return new Int32Array(height * width).fill(-1); }); var padT = padding.top, padL = padding.left; var _b = getScale([height, width], [inHeight, inWidth], padding), scaleX = _b[0], scaleY = _b[1]; var outputResolutionX = getOutputResolution([inHeight, inWidth], stride)[0]; for (var i = 0; i < height; i += 1) { for (var j = 0; j < width; j += 1) { var n = i * width + j; var prob = segmentation[n]; if (prob === 1) { var kMin = matchEmbeddingToInstance({ x: j, y: i }, longOffsets, posesAboveScore, numKptForMatching, [padT, padL], [scaleX, scaleY], outputResolutionX, [height, width], stride, refineSteps); if (kMin >= 0) { dataArrays[kMin][n] = partSegmentaion[n]; } } } } return dataArrays; }
/** * @license * Copyright 2019 Google Inc. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */
// WebGL implementation of multi-person mask decoding. Builds a custom GPGPU
// program that, per output pixel, refines the long-offset embedding and picks
// the nearest pose; returns a TensorInfo of shape [origHeight, origWidth]
// whose value is the matched pose index (or -1 for background).
function decodeMultipleMasksWebGl(segmentation, longOffsets, posesAboveScore, height, width, stride, _a, padding, refineSteps, minKptScore, maxNumPeople) {
var inHeight = _a[0], inWidth = _a[1];
// The height/width of the image/canvas itself.
var _b = segmentation.shape, origHeight = _b[0], origWidth = _b[1];
// The height/width of the output of the model.
var _c = longOffsets.shape.slice(0, 2), outHeight = _c[0], outWidth = _c[1];
var shapedLongOffsets = tf.reshape(longOffsets, [outHeight, outWidth, 2, NUM_KEYPOINTS]);
// Make pose tensor of shape [MAX_NUM_PEOPLE, NUM_KEYPOINTS, 3] where
// the last 3 coordinates correspond to the score, h and w coordinate of that
// keypoint.
var poseVals = new Float32Array(maxNumPeople * NUM_KEYPOINTS * 3).fill(0.0);
for (var i = 0; i < posesAboveScore.length; i++) { var poseOffset = i * NUM_KEYPOINTS * 3; var pose = posesAboveScore[i]; for (var kp = 0; kp < NUM_KEYPOINTS; kp++) { var keypoint = pose.keypoints[kp]; var offset = poseOffset + kp * 3; poseVals[offset] = keypoint.score; poseVals[offset + 1] = keypoint.position.y; poseVals[offset + 2] = keypoint.position.x; } }
var _d = getScale([height, width], [inHeight, inWidth], padding), scaleX = _d[0], scaleY = _d[1];
var posesTensor = tf.tensor(poseVals, [maxNumPeople, NUM_KEYPOINTS, 3]);
var padT = padding.top, padL = padding.left;
// The GLSL below mirrors the CPU path (getEmbedding/ computeDistance):
// sampleLongOffsets bilinearly interpolates the long-offset field, and
// findNearestPose returns the index of the best-matching pose.
var program = { variableNames: ['segmentation', 'longOffsets', 'poses'], outputShape: [origHeight, origWidth], userCode: "\n int convertToPositionInOutput(int pos, int pad, float scale, int stride) {\n return round(((float(pos + pad) + 1.0) * scale - 1.0) / float(stride));\n }\n\n float convertToPositionInOutputFloat(\n int pos, int pad, float scale, int stride) {\n return ((float(pos + pad) + 1.0) * scale - 1.0) / float(stride);\n }\n\n float dist(float x1, float y1, float x2, float y2) {\n return pow(x1 - x2, 2.0) + pow(y1 - y2, 2.0);\n }\n\n float sampleLongOffsets(float h, float w, int d, int k) {\n float fh = fract(h);\n float fw = fract(w);\n int clH = int(ceil(h));\n int clW = int(ceil(w));\n int flH = int(floor(h));\n int flW = int(floor(w));\n float o11 = getLongOffsets(flH, flW, d, k);\n float o12 = getLongOffsets(flH, clW, d, k);\n float o21 = getLongOffsets(clH, flW, d, k);\n float o22 = getLongOffsets(clH, clW, d, k);\n float o1 = mix(o11, o12, fw);\n float o2 = mix(o21, o22, fw);\n return mix(o1, o2, fh);\n }\n\n int findNearestPose(int h, int w) {\n float prob = getSegmentation(h, w);\n if (prob < 1.0) {\n return -1;\n }\n\n // Done(Tyler): convert from output space h/w to strided space.\n float stridedH = convertToPositionInOutputFloat(\n h, ".concat(padT, ", ").concat(scaleY, ", ").concat(stride, ");\n float stridedW = convertToPositionInOutputFloat(\n w, ").concat(padL, ", ").concat(scaleX, ", ").concat(stride, ");\n\n float minDist = 1000000.0;\n int iMin = -1;\n for (int i = 0; i < ").concat(maxNumPeople, "; i++) {\n float curDistSum = 0.0;\n int numKpt = 0;\n for (int k = 0; k < ").concat(NUM_KEYPOINTS, "; k++) {\n float dy = sampleLongOffsets(stridedH, stridedW, 0, k);\n float dx = sampleLongOffsets(stridedH, stridedW, 1, k);\n\n float y = float(h) + dy;\n float x = float(w) + dx;\n\n for (int s = 0; s < ").concat(refineSteps, "; s++) {\n int yRounded = round(min(y, float(").concat(height - 1.0, ")));\n int xRounded = round(min(x, float(").concat(width - 1.0, ")));\n\n float yStrided = convertToPositionInOutputFloat(\n yRounded, ").concat(padT, ", ").concat(scaleY, ", ").concat(stride, ");\n float xStrided = convertToPositionInOutputFloat(\n xRounded, ").concat(padL, ", ").concat(scaleX, ", ").concat(stride, ");\n\n float dy = sampleLongOffsets(yStrided, xStrided, 0, k);\n float dx = sampleLongOffsets(yStrided, xStrided, 1, k);\n\n y = y + dy;\n x = x + dx;\n }\n\n float poseScore = getPoses(i, k, 0);\n float poseY = getPoses(i, k, 1);\n float poseX = getPoses(i, k, 2);\n if (poseScore > ").concat(minKptScore, ") {\n numKpt = numKpt + 1;\n curDistSum = curDistSum + dist(x, y, poseX, poseY);\n }\n }\n if (numKpt > 0 && curDistSum / float(numKpt) < minDist) {\n minDist = curDistSum / float(numKpt);\n iMin = i;\n }\n }\n return iMin;\n }\n\n void main() {\n ivec2 coords = getOutputCoords();\n int nearestPose = findNearestPose(coords[0], coords[1]);\n setOutput(float(nearestPose));\n }\n ") };
// NOTE(review): assumes the active backend is WebGL (callers gate on
// isWebGlBackend()); compileAndRun is a WebGL-backend-specific API.
var webglBackend = tf.backend();
return webglBackend.compileAndRun(program, [segmentation, shapedLongOffsets, posesTensor]); }
/** * @license * Copyright 2019 Google Inc. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */
// Binary mask for person k: 1 where `segmentation` equals k, else 0.
function toPersonKSegmentation(segmentation, k) { return tf.tidy(function () { return tf.cast(tf.equal(segmentation, tf.scalar(k)), 'int32'); }); }
// Part mask for person k: the body-part id where `segmentation` equals k,
// -1 elsewhere (the +1/-1 shift maps masked-out zeros back to -1).
function toPersonKPartSegmentation(segmentation, bodyParts, k) { return tf.tidy(function () { return tf.sub(tf.mul(tf.cast(tf.equal(segmentation, tf.scalar(k)), 'int32'), tf.add(bodyParts, 1)), 1); }); }
function isWebGlBackend() { return getBackend() === 'webgl'; }
// Decodes one binary mask per person: uses the WebGL kernel when available,
// otherwise the CPU path. Resolves to an array of {data, pose, width, height}
// for each pose scoring at least minPoseScore. (Downleveled async function —
// the switch cases are sequential await points.)
function decodePersonInstanceMasks(segmentation, longOffsets, poses, height, width, stride, _a, padding, minPoseScore, refineSteps, minKeypointScore, maxNumPeople) { var inHeight = _a[0], inWidth = _a[1]; if (minPoseScore === void 0) { minPoseScore = 0.2; } if (refineSteps === void 0) { refineSteps = 8; } if (minKeypointScore === void 0) { minKeypointScore = 0.3; } if (maxNumPeople === void 0) { maxNumPeople = 10; } return __awaiter(this, void 0, void 0, function () { var posesAboveScore, personSegmentationsData, personSegmentations, segmentationsData, longOffsetsData; return __generator(this, function (_b) { switch (_b.label) { case 0: posesAboveScore = poses.filter(function (pose) { return pose.score >= minPoseScore; }); if (!isWebGlBackend()) return [3 /*break*/, 2]; personSegmentations = tf.tidy(function () { var masksTensorInfo = decodeMultipleMasksWebGl(segmentation, longOffsets, posesAboveScore, height, width, stride, [inHeight, inWidth], padding, refineSteps, minKeypointScore, maxNumPeople); var masksTensor = tf.engine().makeTensorFromDataId(masksTensorInfo.dataId, masksTensorInfo.shape, masksTensorInfo.dtype); return posesAboveScore.map(function (_, k) { return toPersonKSegmentation(masksTensor, k); }); }); return [4 /*yield*/, Promise.all(personSegmentations.map(function (mask) { return mask.data(); }))]; case 1: personSegmentationsData = (_b.sent()); personSegmentations.forEach(function (x) { return x.dispose(); }); return [3 /*break*/, 5]; case 2: return [4 /*yield*/, segmentation.data()]; case 3: segmentationsData = _b.sent(); return [4 /*yield*/, longOffsets.data()]; case 4: longOffsetsData = _b.sent(); personSegmentationsData = decodeMultipleMasksCPU(segmentationsData, longOffsetsData, posesAboveScore, height, width, stride, [inHeight, inWidth], padding, refineSteps); _b.label = 5; case 5: return [2 /*return*/, personSegmentationsData.map(function (data, i) { return ({ data: data, pose: posesAboveScore[i], width: width, height: height }); })]; } }); }); }
// Same as decodePersonInstanceMasks, but each person's mask carries body-part
// ids (from `partSegmentation`) instead of a binary value.
function decodePersonInstancePartMasks(segmentation, longOffsets, partSegmentation, poses, height, width, stride, _a, padding, minPoseScore, refineSteps, minKeypointScore, maxNumPeople) { var inHeight = _a[0], inWidth = _a[1]; if (minPoseScore === void 0) { minPoseScore = 0.2; } if (refineSteps === void 0) { refineSteps = 8; } if (minKeypointScore === void 0) { minKeypointScore = 0.3; } if (maxNumPeople === void 0) { maxNumPeople = 10; } return __awaiter(this, void 0, void 0, function () { var posesAboveScore, partSegmentationsByPersonData, partSegmentations, segmentationsData, longOffsetsData, partSegmentaionData; return __generator(this, function (_b) { switch (_b.label) { case 0: posesAboveScore = poses.filter(function (pose) { return pose.score >= minPoseScore; }); if (!isWebGlBackend()) return [3 /*break*/, 2]; partSegmentations = tf.tidy(function () { var masksTensorInfo = decodeMultipleMasksWebGl(segmentation, longOffsets, posesAboveScore, height, width, stride, [inHeight, inWidth], padding, refineSteps, minKeypointScore, maxNumPeople); var masksTensor = tf.engine().makeTensorFromDataId(masksTensorInfo.dataId, masksTensorInfo.shape, masksTensorInfo.dtype); return posesAboveScore.map(function (_, k) { return toPersonKPartSegmentation(masksTensor, partSegmentation, k); }); }); return [4 /*yield*/, Promise.all(partSegmentations.map(function (x) { return x.data(); }))]; case 1: partSegmentationsByPersonData = (_b.sent()); partSegmentations.forEach(function (x) { return x.dispose(); }); return [3 /*break*/, 6]; case 2: return [4 /*yield*/, segmentation.data()]; case 3: segmentationsData = _b.sent(); return [4 /*yield*/, longOffsets.data()]; case 4: longOffsetsData = _b.sent(); return [4 /*yield*/, partSegmentation.data()]; case 5: partSegmentaionData = _b.sent(); partSegmentationsByPersonData = decodeMultiplePartMasksCPU(segmentationsData, longOffsetsData, partSegmentaionData, posesAboveScore, height, width, stride, [inHeight, inWidth], padding, refineSteps); _b.label = 6; case 6: return [2 /*return*/, partSegmentationsByPersonData.map(function (data, k) { return ({ pose: posesAboveScore[k], data: data, height: height, width: width }); })]; } }); }); }
/** * @license * Copyright 2019 Google Inc. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ // algorithm based on Coursera Lecture from Algorithms, Part 1: // https://www.coursera.org/learn/algorithms-part1/lecture/ZjoSM/heapsort function half(k) { return Math.floor(k / 2); } var MaxHeap = /** @class */ (function () { function MaxHeap(maxSize, getElementValue) { this.priorityQueue = new Array(maxSize); this.numberOfElements = -1; this.getElementValue = getElementValue; } MaxHeap.prototype.enqueue = function (x) { this.priorityQueue[++this.numberOfElements] = x; this.swim(this.numberOfElements); }; MaxHeap.prototype.dequeue = function () { var max = this.priorityQueue[0]; this.exchange(0, this.numberOfElements--); this.sink(0); this.priorityQueue[this.numberOfElements + 1] = null; return max; }; MaxHeap.prototype.empty = function () { return this.numberOfElements === -1; }; MaxHeap.prototype.size = function () { return this.numberOfElements + 1; }; MaxHeap.prototype.all = function () { return this.priorityQueue.slice(0, this.numberOfElements + 1); }; MaxHeap.prototype.max = function () { return this.priorityQueue[0]; }; MaxHeap.prototype.swim = function (k) { while (k > 0 && this.less(half(k), k)) { this.exchange(k, half(k)); k = half(k); } }; MaxHeap.prototype.sink = function (k) { while (2 * k <= this.numberOfElements) { var j = 2 * k; if (j < this.numberOfElements && this.less(j, j + 1)) { j++; } if (!this.less(k, j)) { break; } this.exchange(k, j); k = j; } }; MaxHeap.prototype.getValueAt = function (i) { return this.getElementValue(this.priorityQueue[i]); }; MaxHeap.prototype.less = function (i, j) { return this.getValueAt(i) < this.getValueAt(j); }; MaxHeap.prototype.exchange = function (i, j) { var t = this.priorityQueue[i]; this.priorityQueue[i] = this.priorityQueue[j]; this.priorityQueue[j] = t; }; return MaxHeap; }()); /** * @license * Copyright 2019 Google LLC. All Rights Reserved. 
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
/**
 * Returns true iff `score` is a strict local maximum of keypoint channel
 * `keypointId` within a square window of radius `localMaximumRadius` around
 * (heatmapY, heatmapX). Ties (equal scores) still count as a maximum, since
 * only strictly greater neighbors reject the candidate.
 *
 * `scores` is indexed via `.shape`/`.get(y, x, channel)` — presumably a tfjs
 * TensorBuffer of shape [height, width, numKeypoints]; confirm at call sites.
 */
function scoreIsMaximumInLocalWindow(keypointId, score, heatmapY, heatmapX, localMaximumRadius, scores) {
    var _a = scores.shape, height = _a[0], width = _a[1];
    var localMaximum = true;
    // Clamp the search window to the heatmap bounds.
    var yStart = Math.max(heatmapY - localMaximumRadius, 0);
    var yEnd = Math.min(heatmapY + localMaximumRadius + 1, height);
    for (var yCurrent = yStart; yCurrent < yEnd; ++yCurrent) {
        var xStart = Math.max(heatmapX - localMaximumRadius, 0);
        var xEnd = Math.min(heatmapX + localMaximumRadius + 1, width);
        for (var xCurrent = xStart; xCurrent < xEnd; ++xCurrent) {
            if (scores.get(yCurrent, xCurrent, keypointId) > score) {
                localMaximum = false;
                break;
            }
        }
        // Flag-and-break to exit both loops once a larger neighbor is found.
        if (!localMaximum) {
            break;
        }
    }
    return localMaximum;
}
/**
 * Builds a priority queue with part candidate positions for a specific image in
 * the batch. For this we find all local maxima in the score maps with score
 * values above a threshold. We create a single priority queue across all parts.
 *
 * Queue capacity is height*width*numKeypoints (worst case: every cell of every
 * channel is a candidate); elements are {score, part:{heatmapY, heatmapX, id}}
 * ordered by descending score.
 */
function buildPartWithScoreQueue(scoreThreshold, localMaximumRadius, scores) {
    var _a = scores.shape, height = _a[0], width = _a[1], numKeypoints = _a[2];
    var queue = new MaxHeap(height * width * numKeypoints, function (_a) {
        var score = _a.score;
        return score;
    });
    for (var heatmapY = 0; heatmapY < height; ++heatmapY) {
        for (var heatmapX = 0; heatmapX < width; ++heatmapX) {
            for (var keypointId = 0; keypointId < numKeypoints; ++keypointId) {
                var score = scores.get(heatmapY, heatmapX, keypointId);
                // Only consider parts with score greater or equal to threshold as
                // root candidates.
                if (score < scoreThreshold) {
                    continue;
                }
                // Only consider keypoints whose score is maximum in a local window.
                if (scoreIsMaximumInLocalWindow(keypointId, score, heatmapY, heatmapX, localMaximumRadius, scores)) {
                    queue.enqueue({ score: score, part: { heatmapY: heatmapY, heatmapX: heatmapX, id: keypointId } });
                }
            }
        }
    }
    return queue;
}
/**
 * @license
 * Copyright 2019 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
// POSE_CHAIN lists (parentName, childName) joint pairs; map names to part ids
// so the decoder can walk the skeleton tree numerically.
var parentChildrenTuples = POSE_CHAIN.map(function (_a) {
    var parentJoinName = _a[0], childJoinName = _a[1];
    return ([PART_IDS[parentJoinName], PART_IDS[childJoinName]]);
});
// For edge e: parentToChildEdges[e] is the child part id,
// childToParentEdges[e] the parent part id.
var parentToChildEdges = parentChildrenTuples.map(function (_a) {
    var childJointId = _a[1];
    return childJointId;
});
var childToParentEdges = parentChildrenTuples.map(function (_a) {
    var parentJointId = _a[0];
    return parentJointId;
});
/**
 * Reads the (y, x) displacement vector for `edgeId` at grid cell `point`.
 * The displacement map stacks all y-components in the first numEdges channels
 * and all x-components in the next numEdges channels.
 */
function getDisplacement(edgeId, point, displacements) {
    var numEdges = displacements.shape[2] / 2;
    return {
        y: displacements.get(point.y, point.x, edgeId),
        x: displacements.get(point.y, point.x, numEdges + edgeId)
    };
}
/**
 * Maps an image-space point to its nearest output-grid cell, clamped to the
 * [0, height-1] x [0, width-1] grid bounds.
 */
function getStridedIndexNearPoint(point, outputStride, height, width) {
    return {
        y: clamp(Math.round(point.y / outputStride), 0, height - 1),
        x: clamp(Math.round(point.x / outputStride), 0, width - 1)
    };
}
/**
 * We get a new keypoint along the `edgeId` for the pose instance, assuming
 * that the position of the `idSource` part is already known. For this, we
 * follow the displacement vector from the source to target part (stored in
 * the `edgeId`-th channel of the displacement tensor). The displaced keypoint
 * vector is refined using the offset vector by `offsetRefineStep` times.
 */
function traverseToTargetKeypoint(edgeId, sourceKeypoint, targetKeypointId, scoresBuffer, offsets, outputStride, displacements, offsetRefineStep) {
    if (offsetRefineStep === void 0) { offsetRefineStep = 2; }
    var _a = scoresBuffer.shape, height = _a[0], width = _a[1];
    // Nearest neighbor interpolation for the source->target displacements.
    var sourceKeypointIndices = getStridedIndexNearPoint(sourceKeypoint.position, outputStride, height, width);
    var displacement = getDisplacement(edgeId, sourceKeypointIndices, displacements);
    var displacedPoint = addVectors(sourceKeypoint.position, displacement);
    var targetKeypoint = displacedPoint;
    // Iteratively snap to the grid and re-apply the sub-stride offset vector.
    for (var i = 0; i < offsetRefineStep; i++) {
        var targetKeypointIndices = getStridedIndexNearPoint(targetKeypoint, outputStride, height, width);
        var offsetPoint = getOffsetPoint(targetKeypointIndices.y, targetKeypointIndices.x, targetKeypointId, offsets);
        targetKeypoint = addVectors({
            x: targetKeypointIndices.x * outputStride,
            y: targetKeypointIndices.y * outputStride
        }, { x: offsetPoint.x, y: offsetPoint.y });
    }
    // Score the refined position from the heatmap cell it lands in.
    var targetKeyPointIndices = getStridedIndexNearPoint(targetKeypoint, outputStride, height, width);
    var score = scoresBuffer.get(targetKeyPointIndices.y, targetKeyPointIndices.x, targetKeypointId);
    return { position: targetKeypoint, part: PART_NAMES[targetKeypointId], score: score };
}
/**
 * Follows the displacement fields to decode the full pose of the object
 * instance given the position of a part that acts as root.
 *
 * Two passes over the skeleton edges: first backward (child -> parent, edges
 * iterated in reverse) then forward (parent -> child), each filling only
 * keypoints not yet decoded, so the root's position is never overwritten.
 *
 * @return An array of decoded keypoints and their scores for a single pose
 */
function decodePose(root, scores, offsets, outputStride, displacementsFwd, displacementsBwd) {
    var numParts = scores.shape[2];
    var numEdges = parentToChildEdges.length;
    var instanceKeypoints = new Array(numParts);
    // Start a new detection instance at the position of the root.
    var rootPart = root.part, rootScore = root.score;
    var rootPoint = getImageCoords(rootPart, outputStride, offsets);
    instanceKeypoints[rootPart.id] = {
        score: rootScore,
        part: PART_NAMES[rootPart.id],
        position: rootPoint
    };
    // Decode the part positions upwards in the tree, following the backward
    // displacements.
    for (var edge = numEdges - 1; edge >= 0; --edge) {
        var sourceKeypointId = parentToChildEdges[edge];
        var targetKeypointId = childToParentEdges[edge];
        if (instanceKeypoints[sourceKeypointId] && !instanceKeypoints[targetKeypointId]) {
            instanceKeypoints[targetKeypointId] = traverseToTargetKeypoint(edge, instanceKeypoints[sourceKeypointId], targetKeypointId, scores, offsets, outputStride, displacementsBwd);
        }
    }
    // Decode the part positions downwards in the tree, following the forward
    // displacements.
    for (var edge = 0; edge < numEdges; ++edge) {
        var sourceKeypointId = childToParentEdges[edge];
        var targetKeypointId = parentToChildEdges[edge];
        if (instanceKeypoints[sourceKeypointId] && !instanceKeypoints[targetKeypointId]) {
            instanceKeypoints[targetKeypointId] = traverseToTargetKeypoint(edge, instanceKeypoints[sourceKeypointId], targetKeypointId, scores, offsets, outputStride, displacementsFwd);
        }
    }
    return instanceKeypoints;
}
/**
 * @license
 * Copyright 2019 Google Inc. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
* ============================================================================= */ function withinNmsRadiusOfCorrespondingPoint(poses, squaredNmsRadius, _a, keypointId) { var x = _a.x, y = _a.y; return poses.some(function (_a) { var keypoints = _a.keypoints; var correspondingKeypoint = keypoints[keypointId].position; return squaredDistance(y, x, correspondingKeypoint.y, correspondingKeypoint.x) <= squaredNmsRadius; }); } /* Score the newly proposed object instance without taking into account * the scores of the parts that overlap with any previously detected * instance. */ function g