@tensorflow/tfjs-core
Version:
Hardware-accelerated JavaScript library for machine intelligence
53 lines • 7.7 kB
JavaScript
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import { UnsortedSegmentSum } from '../kernel_names';
import { expandDims } from '../ops/expand_dims';
import { gather } from '../ops/gather';
import { greaterEqual } from '../ops/greater_equal';
import { logicalAnd } from '../ops/logical_and';
import { maximum } from '../ops/maximum';
import { ones } from '../ops/ones';
import { scalar } from '../ops/scalar';
import { where } from '../ops/where';
import { zerosLike } from '../ops/zeros_like';
/**
 * Gradient registration for the UnsortedSegmentSum kernel.
 *
 * Only the `segmentIds` input is kept from the forward pass. The gradient for
 * `x` is produced lazily: calling the returned `x` thunk gathers `dy` at each
 * segment id, yielding zeros wherever the id is negative
 * (see gatherDropNegatives).
 */
export const unsortedSegmentSumGradConfig = {
  kernelName: UnsortedSegmentSum,
  inputsToSave: ['segmentIds'],
  gradFunc: (dy, saved) => {
    // `saved` follows the order of `inputsToSave`, so the ids are first.
    const segmentIds = saved[0];
    return {
      x: () => gatherDropNegatives(dy, segmentIds),
    };
  },
};
/**
 * Gathers slices of `x` at `indices`, substituting zeros for every slice whose
 * index is negative. Helper for the unsorted segment op gradients; mirrors
 * `_GatherDropNegatives` from tensorflow/python/ops/math_grad.py.
 *
 * @param x Tensor to gather from (passed straight to `gather`).
 * @param indices Segment ids, compared against an int32 zero. Presumably
 *     1-D: the mask expansion below inserts axes at positions 1, 2, ..., which
 *     only appends trailing axes when the mask starts at rank 1.
 * @returns Tensor shaped like the gather result, equal to the gathered values
 *     where the index is >= 0 and zero elsewhere.
 */
function gatherDropNegatives(x, indices) {
  // Clamp negative ids to 0 so gather() only reads valid rows; those rows are
  // masked back to zero at the end.
  const gathered = gather(x, maximum(indices, zerosLike(indices)));

  // Keep-mask (id >= 0), grown with singleton axes until its rank matches
  // `gathered`.
  let keep = greaterEqual(indices, scalar(0, 'int32'));
  const missingAxes = gathered.rank - keep.rank;
  for (let axis = 1; axis <= missingAxes; axis++) {
    keep = expandDims(keep, axis);
  }

  // AND with an all-true tensor of the gathered shape so the condition handed
  // to where() has the same shape as both of its branches.
  keep = logicalAnd(keep, ones(gathered.shape, 'bool'));

  return where(keep, gathered, zerosLike(gathered));
}
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiVW5zb3J0ZWRTZWdtZW50U3VtX2dyYWQuanMiLCJzb3VyY2VSb290IjoiIiwic291cmNlcyI6WyIuLi8uLi8uLi8uLi8uLi8uLi90ZmpzLWNvcmUvc3JjL2dyYWRpZW50cy9VbnNvcnRlZFNlZ21lbnRTdW1fZ3JhZC50cyJdLCJuYW1lcyI6W10sIm1hcHBpbmdzIjoiQUFBQTs7Ozs7Ozs7Ozs7Ozs7O0dBZUc7QUFFSCxPQUFPLEVBQUMsa0JBQWtCLEVBQUMsTUFBTSxpQkFBaUIsQ0FBQztBQUVuRCxPQUFPLEVBQUMsVUFBVSxFQUFDLE1BQU0sb0JBQW9CLENBQUM7QUFDOUMsT0FBTyxFQUFDLE1BQU0sRUFBQyxNQUFNLGVBQWUsQ0FBQztBQUNyQyxPQUFPLEVBQUMsWUFBWSxFQUFDLE1BQU0sc0JBQXNCLENBQUM7QUFDbEQsT0FBTyxFQUFDLFVBQVUsRUFBQyxNQUFNLG9CQUFvQixDQUFDO0FBQzlDLE9BQU8sRUFBQyxPQUFPLEVBQUMsTUFBTSxnQkFBZ0IsQ0FBQztBQUN2QyxPQUFPLEVBQUMsSUFBSSxFQUFDLE1BQU0sYUFBYSxDQUFDO0FBQ2pDLE9BQU8sRUFBQyxNQUFNLEVBQUMsTUFBTSxlQUFlLENBQUM7QUFDckMsT0FBTyxFQUFDLEtBQUssRUFBQyxNQUFNLGNBQWMsQ0FBQztBQUNuQyxPQUFPLEVBQUMsU0FBUyxFQUFDLE1BQU0sbUJBQW1CLENBQUM7QUFHNUMsTUFBTSxDQUFDLE1BQU0sNEJBQTRCLEdBQWU7SUFDdEQsVUFBVSxFQUFFLGtCQUFrQjtJQUM5QixZQUFZLEVBQUUsQ0FBQyxZQUFZLENBQUM7SUFDNUIsUUFBUSxFQUFFLENBQUMsRUFBVSxFQUFFLEtBQWUsRUFBRSxFQUFFO1FBQ3hDLE1BQU0sQ0FBQyxVQUFVLENBQUMsR0FBRyxLQUFLLENBQUM7UUFFM0IsTUFBTSxJQUFJLEdBQUcsR0FBRyxFQUFFO1lBQ2hCLE9BQU8sbUJBQW1CLENBQUMsRUFBRSxFQUFFLFVBQXNCLENBQUMsQ0FBQztRQUN6RCxDQUFDLENBQUM7UUFDRixPQUFPLEVBQUMsQ0FBQyxFQUFFLElBQUksRUFBQyxDQUFDO0lBQ25CLENBQUM7Q0FDRixDQUFDO0FBRUYsU0FBUyxtQkFBbUIsQ0FBbUIsQ0FBSSxFQUFFLE9BQWlCO0lBQ3BFLCtEQUErRDtJQUMvRCwwRUFBMEU7SUFDMUUsdUVBQXVFO0lBQ3ZFLE1BQU0sa0JBQWtCLEdBQUcsT0FBTyxDQUFDLE9BQU8sRUFBRSxTQUFTLENBQUMsT0FBTyxDQUFDLENBQUMsQ0FBQztJQUNoRSxNQUFNLFFBQVEsR0FBRyxNQUFNLENBQUMsQ0FBQyxFQUFFLGtCQUE4QixDQUFDLENBQUM7SUFDM0QsSUFBSSxVQUFVLEdBQUcsWUFBWSxDQUFDLE9BQU8sRUFBRSxNQUFNLENBQUMsQ0FBQyxFQUFFLE9BQU8sQ0FBQyxDQUFDLENBQUM7SUFDM0QsTUFBTSxRQUFRLEdBQUcsUUFBUSxDQUFDLElBQUksR0FBRyxVQUFVLENBQUMsSUFBSSxDQUFDO0lBQ2pELEtBQUssSUFBSSxDQUFDLEdBQUcsQ0FBQyxFQUFFLENBQUMsR0FBRyxRQUFRLEVBQUUsRUFBRSxDQUFDLEVBQUU7UUFDakMsVUFBVSxHQUFHLFVBQVUsQ0FBQyxVQUFVLEVBQUUsQ0FBQyxHQUFHLENBQUMsQ0FBQyxDQUFDO0tBQzVDO0lBQ0QsVUFBVSxHQUFHLFVBQVUsQ0FBQy
xVQUFVLEVBQUUsSUFBSSxDQUFDLFFBQVEsQ0FBQyxLQUFLLEVBQUUsTUFBTSxDQUFDLENBQUMsQ0FBQztJQUNsRSxNQUFNLFNBQVMsR0FBRyxTQUFTLENBQUMsUUFBUSxDQUFDLENBQUM7SUFDdEMsT0FBTyxLQUFLLENBQUMsVUFBVSxFQUFFLFFBQVEsRUFBRSxTQUFTLENBQUMsQ0FBQztBQUNoRCxDQUFDIiwic291cmNlc0NvbnRlbnQiOlsiLyoqXG4gKiBAbGljZW5zZVxuICogQ29weXJpZ2h0IDIwMjAgR29vZ2xlIExMQy4gQWxsIFJpZ2h0cyBSZXNlcnZlZC5cbiAqIExpY2Vuc2VkIHVuZGVyIHRoZSBBcGFjaGUgTGljZW5zZSwgVmVyc2lvbiAyLjAgKHRoZSBcIkxpY2Vuc2VcIik7XG4gKiB5b3UgbWF5IG5vdCB1c2UgdGhpcyBmaWxlIGV4Y2VwdCBpbiBjb21wbGlhbmNlIHdpdGggdGhlIExpY2Vuc2UuXG4gKiBZb3UgbWF5IG9idGFpbiBhIGNvcHkgb2YgdGhlIExpY2Vuc2UgYXRcbiAqXG4gKiBodHRwOi8vd3d3LmFwYWNoZS5vcmcvbGljZW5zZXMvTElDRU5TRS0yLjBcbiAqXG4gKiBVbmxlc3MgcmVxdWlyZWQgYnkgYXBwbGljYWJsZSBsYXcgb3IgYWdyZWVkIHRvIGluIHdyaXRpbmcsIHNvZnR3YXJlXG4gKiBkaXN0cmlidXRlZCB1bmRlciB0aGUgTGljZW5zZSBpcyBkaXN0cmlidXRlZCBvbiBhbiBcIkFTIElTXCIgQkFTSVMsXG4gKiBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC5cbiAqIFNlZSB0aGUgTGljZW5zZSBmb3IgdGhlIHNwZWNpZmljIGxhbmd1YWdlIGdvdmVybmluZyBwZXJtaXNzaW9ucyBhbmRcbiAqIGxpbWl0YXRpb25zIHVuZGVyIHRoZSBMaWNlbnNlLlxuICogPT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT1cbiAqL1xuXG5pbXBvcnQge1Vuc29ydGVkU2VnbWVudFN1bX0gZnJvbSAnLi4va2VybmVsX25hbWVzJztcbmltcG9ydCB7R3JhZENvbmZpZ30gZnJvbSAnLi4va2VybmVsX3JlZ2lzdHJ5JztcbmltcG9ydCB7ZXhwYW5kRGltc30gZnJvbSAnLi4vb3BzL2V4cGFuZF9kaW1zJztcbmltcG9ydCB7Z2F0aGVyfSBmcm9tICcuLi9vcHMvZ2F0aGVyJztcbmltcG9ydCB7Z3JlYXRlckVxdWFsfSBmcm9tICcuLi9vcHMvZ3JlYXRlcl9lcXVhbCc7XG5pbXBvcnQge2xvZ2ljYWxBbmR9IGZyb20gJy4uL29wcy9sb2dpY2FsX2FuZCc7XG5pbXBvcnQge21heGltdW19IGZyb20gJy4uL29wcy9tYXhpbXVtJztcbmltcG9ydCB7b25lc30gZnJvbSAnLi4vb3BzL29uZXMnO1xuaW1wb3J0IHtzY2FsYXJ9IGZyb20gJy4uL29wcy9zY2FsYXInO1xuaW1wb3J0IHt3aGVyZX0gZnJvbSAnLi4vb3BzL3doZXJlJztcbmltcG9ydCB7emVyb3NMaWtlfSBmcm9tICcuLi9vcHMvemVyb3NfbGlrZSc7XG5pbXBvcnQge1RlbnNvciwgVGVuc29yMUR9IGZyb20gJy4uL3RlbnNvcic7XG5cbmV4cG9ydCBjb25zdCB1bnNvcnRlZFNlZ21lbnRTdW1HcmFkQ29uZmlnOiBHcmFkQ2
9uZmlnID0ge1xuICBrZXJuZWxOYW1lOiBVbnNvcnRlZFNlZ21lbnRTdW0sXG4gIGlucHV0c1RvU2F2ZTogWydzZWdtZW50SWRzJ10sXG4gIGdyYWRGdW5jOiAoZHk6IFRlbnNvciwgc2F2ZWQ6IFRlbnNvcltdKSA9PiB7XG4gICAgY29uc3QgW3NlZ21lbnRJZHNdID0gc2F2ZWQ7XG5cbiAgICBjb25zdCBkZXJYID0gKCkgPT4ge1xuICAgICAgcmV0dXJuIGdhdGhlckRyb3BOZWdhdGl2ZXMoZHksIHNlZ21lbnRJZHMgYXMgVGVuc29yMUQpO1xuICAgIH07XG4gICAgcmV0dXJuIHt4OiBkZXJYfTtcbiAgfVxufTtcblxuZnVuY3Rpb24gZ2F0aGVyRHJvcE5lZ2F0aXZlczxUIGV4dGVuZHMgVGVuc29yPih4OiBULCBpbmRpY2VzOiBUZW5zb3IxRCkge1xuICAvLyBIZWxwZXIgZnVuY3Rpb24gZm9yIHVuc29ydGVkIHNlZ21lbnQgb3BzLiBHYXRoZXJzIHBhcmFtcyBmb3JcbiAgLy8gcG9zaXRpdmUgc2VnbWVudCBpZHMgYW5kIGdhdGhlcnMgMCBmb3IgaW5wdXRzIHdpdGggbmVnYXRpdmUgc2VnbWVudCBpZC5cbiAgLy8gTWlycm9ycyBfR2F0aGVyRHJvcE5lZ2F0aXZlcyBmcm9tIHRlbnNvcmZsb3cvcHl0aG9uL29wcy9tYXRoX2dyYWQucHlcbiAgY29uc3QgemVyb0NsaXBwZWRJbmRpY2VzID0gbWF4aW11bShpbmRpY2VzLCB6ZXJvc0xpa2UoaW5kaWNlcykpO1xuICBjb25zdCBnYXRoZXJlZCA9IGdhdGhlcih4LCB6ZXJvQ2xpcHBlZEluZGljZXMgYXMgVGVuc29yMUQpO1xuICBsZXQgaXNQb3NpdGl2ZSA9IGdyZWF0ZXJFcXVhbChpbmRpY2VzLCBzY2FsYXIoMCwgJ2ludDMyJykpO1xuICBjb25zdCBudW1JdGVycyA9IGdhdGhlcmVkLnJhbmsgLSBpc1Bvc2l0aXZlLnJhbms7XG4gIGZvciAobGV0IGkgPSAwOyBpIDwgbnVtSXRlcnM7ICsraSkge1xuICAgIGlzUG9zaXRpdmUgPSBleHBhbmREaW1zKGlzUG9zaXRpdmUsIGkgKyAxKTtcbiAgfVxuICBpc1Bvc2l0aXZlID0gbG9naWNhbEFuZChpc1Bvc2l0aXZlLCBvbmVzKGdhdGhlcmVkLnNoYXBlLCAnYm9vbCcpKTtcbiAgY29uc3QgemVyb1NsaWNlID0gemVyb3NMaWtlKGdhdGhlcmVkKTtcbiAgcmV0dXJuIHdoZXJlKGlzUG9zaXRpdmUsIGdhdGhlcmVkLCB6ZXJvU2xpY2UpO1xufVxuIl19