UNPKG

@tensorflow/tfjs-core

Version:

Hardware-accelerated JavaScript library for machine intelligence

161 lines (136 loc) 4.95 kB
/**
 * @license
 * Copyright 2018 Google Inc. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */

import {SegOpInfo} from '../../ops/segment_util';
import {GPGPUProgram} from './gpgpu_math';

/**
 * WebGL program computing a windowed unsorted segment sum.
 *
 * Output element `[batch, outIdx]` is the sum of the `x` values in input
 * window `floor(outIdx / numSegments)` (each window covers `windowSize`
 * consecutive input columns starting at `window * windowSize`) whose segment
 * id equals `outIdx % numSegments`. Each batch row of the output therefore has
 * `numSegments * ceil(inSize / windowSize)` columns.
 */
export class SegmentOpProgram implements GPGPUProgram {
  variableNames = ['x', 'segmentIds'];
  outputShape: number[];
  userCode: string;

  /**
   * @param segOpInfo Supplies windowSize, batchSize, inSize and numSegments,
   *     which size the output and are baked into the generated shader.
   * @param segOpType The segment reduction to perform. Only
   *     'unsortedSegmentSum' is accepted by the type; the value itself is not
   *     read.
   */
  constructor(segOpInfo: SegOpInfo, segOpType: 'unsortedSegmentSum') {
    const windowSize = segOpInfo.windowSize;
    const batchSize = segOpInfo.batchSize;
    const inSize = segOpInfo.inSize;
    const numSegments = segOpInfo.numSegments;
    const outSize = numSegments * Math.ceil(inSize / windowSize);
    this.outputShape = [batchSize, outSize];

    const initializationValue = '0.0';
    const returnValue = `sumValue`;

    // The window is consumed four elements at a time; the trailing
    // `windowSize % 4` elements are handled by a separate padded step.
    const windowSizeNearestVec4 = Math.floor(windowSize / 4) * 4;
    const windowSizeVec4Remainder = windowSize % 4;

    const updateSnippet = `
        sumValue += dot(values, segFilter);
    `;

    // Only when inSize is not a multiple of windowSize can the last window run
    // past the end of the input, so only then are reads bounds-checked.
    const needsBoundsCheck = inSize % windowSize > 0;

    const checkValueOutOfBounds = needsBoundsCheck ? `
        if (inIdx < 0 || inIdx >= ${inSize}) {
          return initializationValue;
        }
      ` :
                                                     '';

    // Out-of-range positions report segment id -1. currentSeg is always in
    // [0, numSegments), so such positions are filtered out of the sum.
    const checkSegmentIdOutOfBounds = needsBoundsCheck ? `
        if (inIdx < 0 || inIdx >= ${inSize}) {
          return -1.0;
        }
      ` :
                                                         '';

    // Emit the tail step only when the window has a remainder, padding the
    // unused vec4 lanes with a zero filter so they never contribute.
    let remainderSnippet = '';
    if (windowSizeVec4Remainder > 0) {
      const lanes = [0, 1, 2, 3];
      const laneIdx = (lane: number) =>
          lane === 0 ? 'inIdx' : `inIdx + ${lane}`;
      const valueLanes = lanes.map(
          lane => lane < windowSizeVec4Remainder ?
              `getValue(batch, ${laneIdx(lane)})` :
              'initializationValue');
      const filterLanes = lanes.map(
          lane => lane < windowSizeVec4Remainder ?
              `int(getSegmentIdAtIndex(${laneIdx(lane)})) == currentSeg ? 1 : 0` :
              '0');
      remainderSnippet = `
        int inIdx = inOffset + ${windowSizeNearestVec4};
        vec4 values = vec4(
          ${valueLanes.join(',\n          ')}
        );

        vec4 segFilter = vec4(
          ${filterLanes.join(',\n          ')}
        );

        ${updateSnippet}
      `;
    }

    // The window index and segment id come from one exact integer quotient.
    // The float form floor(float(a) / float(b)) can round k*b/b down to k-1
    // on GPUs whose division is not correctly rounded.
    this.userCode = `
      const float initializationValue = ${initializationValue};

      float getValue(int batch, int inIdx) {
        ${checkValueOutOfBounds}
        return getX(batch, inIdx);
      }

      float getSegmentIdAtIndex(int inIdx) {
        ${checkSegmentIdOutOfBounds}
        return getSegmentIds(inIdx);
      }

      void main() {
        ivec2 coords = getOutputCoords();
        int batch = coords[0];
        int outIdx = coords[1];
        int windowIdx = outIdx / ${numSegments};
        int currentSeg = outIdx - windowIdx * ${numSegments};
        int inOffset = windowIdx * ${windowSize};

        float sumValue = 0.0;

        for (int i = 0; i < ${windowSizeNearestVec4}; i += 4) {
          int inIdx = inOffset + i;
          vec4 values = vec4(
            getValue(batch, inIdx),
            getValue(batch, inIdx + 1),
            getValue(batch, inIdx + 2),
            getValue(batch, inIdx + 3)
          );

          vec4 segFilter = vec4(
            int(getSegmentIdAtIndex(inIdx)) == currentSeg ? 1 : 0,
            int(getSegmentIdAtIndex(inIdx + 1)) == currentSeg ? 1 : 0,
            int(getSegmentIdAtIndex(inIdx + 2)) == currentSeg ? 1 : 0,
            int(getSegmentIdAtIndex(inIdx + 3)) == currentSeg ? 1 : 0
          );

          ${updateSnippet}
        }

        ${remainderSnippet}

        setOutput(${returnValue});
      }
    `;
  }
}