UNPKG

@tensorflow/tfjs-core

Version:

Hardware-accelerated JavaScript library for machine intelligence

53 lines 7.7 kB
/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { UnsortedSegmentSum } from '../kernel_names';
import { expandDims } from '../ops/expand_dims';
import { gather } from '../ops/gather';
import { greaterEqual } from '../ops/greater_equal';
import { logicalAnd } from '../ops/logical_and';
import { maximum } from '../ops/maximum';
import { ones } from '../ops/ones';
import { scalar } from '../ops/scalar';
import { where } from '../ops/where';
import { zerosLike } from '../ops/zeros_like';

/**
 * Gradient config for the UnsortedSegmentSum kernel.
 *
 * Only `segmentIds` is saved from the forward pass. The gradient w.r.t. `x`
 * gathers, for every input row, the upstream gradient of the segment that
 * row was summed into; rows whose segment id is negative get zero gradient.
 * No gradient is returned for `segmentIds`.
 */
export const unsortedSegmentSumGradConfig = {
  kernelName: UnsortedSegmentSum,
  inputsToSave: ['segmentIds'],
  gradFunc: (dy, saved) => {
    const [segmentIds] = saved;
    const derX = () => gatherDropNegatives(dy, segmentIds);
    return { x: derX };
  },
};

/**
 * Gathers slices of `x` along axis 0 at `indices`, substituting zeros for
 * every position whose index is negative.
 *
 * Helper for the unsorted segment ops. Mirrors `_GatherDropNegatives` from
 * tensorflow/python/ops/math_grad.py.
 *
 * @param {Tensor} x Tensor to gather from along axis 0.
 * @param {Tensor} indices Indices into axis 0 of `x` (assumed int32 — they
 *     are compared against an int32 zero). Negative entries produce zeros.
 * @returns {Tensor} Tensor shaped like `gather(x, indices)`, i.e.
 *     `indices.shape` followed by `x.shape.slice(1)`.
 */
function gatherDropNegatives(x, indices) {
  // Clip negative indices to 0 so `gather` only sees valid positions; the
  // values fetched for those positions are replaced with zeros below.
  const zeroClippedIndices = maximum(indices, zerosLike(indices));
  const gathered = gather(x, zeroClippedIndices);

  let isNonNegative = greaterEqual(indices, scalar(0, 'int32'));

  // Build a boolean mask with exactly `gathered`'s shape for `where`:
  // append one trailing size-1 axis per extra dimension of `gathered`
  // (always at the mask's current last position, so this is correct for
  // any rank of `indices`), then broadcast to the full shape.
  const extraDims = gathered.rank - isNonNegative.rank;
  for (let i = 0; i < extraDims; ++i) {
    isNonNegative = expandDims(isNonNegative, isNonNegative.rank);
  }
  isNonNegative = logicalAnd(isNonNegative, ones(gathered.shape, 'bool'));

  const zeroSlice = zerosLike(gathered);
  return where(isNonNegative, gathered, zeroSlice);
}
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiVW5zb3J0ZWRTZWdtZW50U3VtX2dyYWQuanMiLCJzb3VyY2VSb290IjoiIiwic291cmNlcyI6WyIuLi8uLi8uLi8uLi8uLi8uLi90ZmpzLWNvcmUvc3JjL2dyYWRpZW50cy9VbnNvcnRlZFNlZ21lbnRTdW1fZ3JhZC50cyJdLCJuYW1lcyI6W10sIm1hcHBpbmdzIjoiQUFBQTs7Ozs7Ozs7Ozs7Ozs7O0dBZUc7QUFFSCxPQUFPLEVBQUMsa0JBQWtCLEVBQUMsTUFBTSxpQkFBaUIsQ0FBQztBQUVuRCxPQUFPLEVBQUMsVUFBVSxFQUFDLE1BQU0sb0JBQW9CLENBQUM7QUFDOUMsT0FBTyxFQUFDLE1BQU0sRUFBQyxNQUFNLGVBQWUsQ0FBQztBQUNyQyxPQUFPLEVBQUMsWUFBWSxFQUFDLE1BQU0sc0JBQXNCLENBQUM7QUFDbEQsT0FBTyxFQUFDLFVBQVUsRUFBQyxNQUFNLG9CQUFvQixDQUFDO0FBQzlDLE9BQU8sRUFBQyxPQUFPLEVBQUMsTUFBTSxnQkFBZ0IsQ0FBQztBQUN2QyxPQUFPLEVBQUMsSUFBSSxFQUFDLE1BQU0sYUFBYSxDQUFDO0FBQ2pDLE9BQU8sRUFBQyxNQUFNLEVBQUMsTUFBTSxlQUFlLENBQUM7QUFDckMsT0FBTyxFQUFDLEtBQUssRUFBQyxNQUFNLGNBQWMsQ0FBQztBQUNuQyxPQUFPLEVBQUMsU0FBUyxFQUFDLE1BQU0sbUJBQW1CLENBQUM7QUFHNUMsTUFBTSxDQUFDLE1BQU0sNEJBQTRCLEdBQWU7SUFDdEQsVUFBVSxFQUFFLGtCQUFrQjtJQUM5QixZQUFZLEVBQUUsQ0FBQyxZQUFZLENBQUM7SUFDNUIsUUFBUSxFQUFFLENBQUMsRUFBVSxFQUFFLEtBQWUsRUFBRSxFQUFFO1FBQ3hDLE1BQU0sQ0FBQyxVQUFVLENBQUMsR0FBRyxLQUFLLENBQUM7UUFFM0IsTUFBTSxJQUFJLEdBQUcsR0FBRyxFQUFFO1lBQ2hCLE9BQU8sbUJBQW1CLENBQUMsRUFBRSxFQUFFLFVBQXNCLENBQUMsQ0FBQztRQUN6RCxDQUFDLENBQUM7UUFDRixPQUFPLEVBQUMsQ0FBQyxFQUFFLElBQUksRUFBQyxDQUFDO0lBQ25CLENBQUM7Q0FDRixDQUFDO0FBRUYsU0FBUyxtQkFBbUIsQ0FBbUIsQ0FBSSxFQUFFLE9BQWlCO0lBQ3BFLCtEQUErRDtJQUMvRCwwRUFBMEU7SUFDMUUsdUVBQXVFO0lBQ3ZFLE1BQU0sa0JBQWtCLEdBQUcsT0FBTyxDQUFDLE9
U8sRUFBRSxTQUFTLENBQUMsT0FBTyxDQUFDLENBQUMsQ0FBQztJQUNoRSxNQUFNLFFBQVEsR0FBRyxNQUFNLENBQUMsQ0FBQyxFQUFFLGtCQUE4QixDQUFDLENBQUM7SUFDM0QsSUFBSSxVQUFVLEdBQUcsWUFBWSxDQUFDLE9BQU8sRUFBRSxNQUFNLENBQUMsQ0FBQyxFQUFFLE9BQU8sQ0FBQyxDQUFDLENBQUM7SUFDM0QsTUFBTSxRQUFRLEdBQUcsUUFBUSxDQUFDLElBQUksR0FBRyxVQUFVLENBQUMsSUFBSSxDQUFDO0lBQ2pELEtBQUssSUFBSSxDQUFDLEdBQUcsQ0FBQyxFQUFFLENBQUMsR0FBRyxRQUFRLEVBQUUsRUFBRSxDQUFDLEVBQUU7UUFDakMsVUFBVSxHQUFHLFVBQVUsQ0FBQyxVQUFVLEVBQUUsQ0FBQyxHQUFHLENBQUMsQ0FBQyxDQUFDO0tBQzVDO0lBQ0QsVUFBVSxHQUFHLFVBQVUsQ0FBQyxVQUFVLEVBQUUsSUFBSSxDQUFDLFFBQVEsQ0FBQyxLQUFLLEVBQUUsTUFBTSxDQUFDLENBQUMsQ0FBQztJQUNsRSxNQUFNLFNBQVMsR0FBRyxTQUFTLENBQUMsUUFBUSxDQUFDLENBQUM7SUFDdEMsT0FBTyxLQUFLLENBQUMsVUFBVSxFQUFFLFFBQVEsRUFBRSxTQUFTLENBQUMsQ0FBQztBQUNoRCxDQUFDIiwic291cmNlc0NvbnRlbnQiOlsiLyoqXG4gKiBAbGljZW5zZVxuICogQ29weXJpZ2h0IDIwMjAgR29vZ2xlIExMQy4gQWxsIFJpZ2h0cyBSZXNlcnZlZC5cbiAqIExpY2Vuc2VkIHVuZGVyIHRoZSBBcGFjaGUgTGljZW5zZSwgVmVyc2lvbiAyLjAgKHRoZSBcIkxpY2Vuc2VcIik7XG4gKiB5b3UgbWF5IG5vdCB1c2UgdGhpcyBmaWxlIGV4Y2VwdCBpbiBjb21wbGlhbmNlIHdpdGggdGhlIExpY2Vuc2UuXG4gKiBZb3UgbWF5IG9idGFpbiBhIGNvcHkgb2YgdGhlIExpY2Vuc2UgYXRcbiAqXG4gKiBodHRwOi8vd3d3LmFwYWNoZS5vcmcvbGljZW5zZXMvTElDRU5TRS0yLjBcbiAqXG4gKiBVbmxlc3MgcmVxdWlyZWQgYnkgYXBwbGljYWJsZSBsYXcgb3IgYWdyZWVkIHRvIGluIHdyaXRpbmcsIHNvZnR3YXJlXG4gKiBkaXN0cmlidXRlZCB1bmRlciB0aGUgTGljZW5zZSBpcyBkaXN0cmlidXRlZCBvbiBhbiBcIkFTIElTXCIgQkFTSVMsXG4gKiBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC5cbiAqIFNlZSB0aGUgTGljZW5zZSBmb3IgdGhlIHNwZWNpZmljIGxhbmd1YWdlIGdvdmVybmluZyBwZXJtaXNzaW9ucyBhbmRcbiAqIGxpbWl0YXRpb25zIHVuZGVyIHRoZSBMaWNlbnNlLlxuICogPT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT1cbiAqL1xuXG5pbXBvcnQge1Vuc29ydGVkU2VnbWVudFN1bX0gZnJvbSAnLi4va2VybmVsX25hbWVzJztcbmltcG9ydCB7R3JhZENvbmZpZ30gZnJvbSAnLi4va2VybmVsX3JlZ2lzdHJ5JztcbmltcG9ydCB7ZXhwYW5kRGltc30gZnJvbSAnLi4vb3BzL2V4cGFuZF9kaW1zJztcbmltcG9ydCB7Z2F0aGVyfSBmcm9tICcuLi9vcHMvZ2F0a
GVyJztcbmltcG9ydCB7Z3JlYXRlckVxdWFsfSBmcm9tICcuLi9vcHMvZ3JlYXRlcl9lcXVhbCc7XG5pbXBvcnQge2xvZ2ljYWxBbmR9IGZyb20gJy4uL29wcy9sb2dpY2FsX2FuZCc7XG5pbXBvcnQge21heGltdW19IGZyb20gJy4uL29wcy9tYXhpbXVtJztcbmltcG9ydCB7b25lc30gZnJvbSAnLi4vb3BzL29uZXMnO1xuaW1wb3J0IHtzY2FsYXJ9IGZyb20gJy4uL29wcy9zY2FsYXInO1xuaW1wb3J0IHt3aGVyZX0gZnJvbSAnLi4vb3BzL3doZXJlJztcbmltcG9ydCB7emVyb3NMaWtlfSBmcm9tICcuLi9vcHMvemVyb3NfbGlrZSc7XG5pbXBvcnQge1RlbnNvciwgVGVuc29yMUR9IGZyb20gJy4uL3RlbnNvcic7XG5cbmV4cG9ydCBjb25zdCB1bnNvcnRlZFNlZ21lbnRTdW1HcmFkQ29uZmlnOiBHcmFkQ29uZmlnID0ge1xuICBrZXJuZWxOYW1lOiBVbnNvcnRlZFNlZ21lbnRTdW0sXG4gIGlucHV0c1RvU2F2ZTogWydzZWdtZW50SWRzJ10sXG4gIGdyYWRGdW5jOiAoZHk6IFRlbnNvciwgc2F2ZWQ6IFRlbnNvcltdKSA9PiB7XG4gICAgY29uc3QgW3NlZ21lbnRJZHNdID0gc2F2ZWQ7XG5cbiAgICBjb25zdCBkZXJYID0gKCkgPT4ge1xuICAgICAgcmV0dXJuIGdhdGhlckRyb3BOZWdhdGl2ZXMoZHksIHNlZ21lbnRJZHMgYXMgVGVuc29yMUQpO1xuICAgIH07XG4gICAgcmV0dXJuIHt4OiBkZXJYfTtcbiAgfVxufTtcblxuZnVuY3Rpb24gZ2F0aGVyRHJvcE5lZ2F0aXZlczxUIGV4dGVuZHMgVGVuc29yPih4OiBULCBpbmRpY2VzOiBUZW5zb3IxRCkge1xuICAvLyBIZWxwZXIgZnVuY3Rpb24gZm9yIHVuc29ydGVkIHNlZ21lbnQgb3BzLiBHYXRoZXJzIHBhcmFtcyBmb3JcbiAgLy8gcG9zaXRpdmUgc2VnbWVudCBpZHMgYW5kIGdhdGhlcnMgMCBmb3IgaW5wdXRzIHdpdGggbmVnYXRpdmUgc2VnbWVudCBpZC5cbiAgLy8gTWlycm9ycyBfR2F0aGVyRHJvcE5lZ2F0aXZlcyBmcm9tIHRlbnNvcmZsb3cvcHl0aG9uL29wcy9tYXRoX2dyYWQucHlcbiAgY29uc3QgemVyb0NsaXBwZWRJbmRpY2VzID0gbWF4aW11bShpbmRpY2VzLCB6ZXJvc0xpa2UoaW5kaWNlcykpO1xuICBjb25zdCBnYXRoZXJlZCA9IGdhdGhlcih4LCB6ZXJvQ2xpcHBlZEluZGljZXMgYXMgVGVuc29yMUQpO1xuICBsZXQgaXNQb3NpdGl2ZSA9IGdyZWF0ZXJFcXVhbChpbmRpY2VzLCBzY2FsYXIoMCwgJ2ludDMyJykpO1xuICBjb25zdCBudW1JdGVycyA9IGdhdGhlcmVkLnJhbmsgLSBpc1Bvc2l0aXZlLnJhbms7XG4gIGZvciAobGV0IGkgPSAwOyBpIDwgbnVtSXRlcnM7ICsraSkge1xuICAgIGlzUG9zaXRpdmUgPSBleHBhbmREaW1zKGlzUG9zaXRpdmUsIGkgKyAxKTtcbiAgfVxuICBpc1Bvc2l0aXZlID0gbG9naWNhbEFuZChpc1Bvc2l0aXZlLCBvbmVzKGdhdGhlcmVkLnNoYXBlLCAnYm9vbCcpKTtcbiAgY29uc3QgemVyb1NsaWNlID0gemVyb3NMaWtlKGdhdGhlcmVkKTtcbiAgcmV0dXJuIHdoZXJlKGlzUG9zaXRpdmUsIGdhdGhlcmVkLCB6ZXJvU2xpY2UpO1xufVxuIl19