@tensorflow/tfjs-core
Version:
Hardware-accelerated JavaScript library for machine intelligence
161 lines (136 loc) • 4.95 kB
text/typescript
/**
* @license
* Copyright 2018 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import {SegOpInfo} from '../../ops/segment_util';
import {GPGPUProgram} from './gpgpu_math';
/**
 * GPGPU program that computes per-window partial sums for an unsorted
 * segment sum.
 *
 * The shader reads `x` with two coordinates `(batch, inIdx)` and `segmentIds`
 * with one coordinate `(inIdx)`. The `inIdx` axis is split into consecutive
 * windows of `windowSize` elements. For every window and every segment index
 * in `[0, numSegments)` one output value is produced: the sum of the window's
 * `x` values whose segment id equals that segment index.
 *
 * The output shape is `[batchSize, numSegments * ceil(inSize / windowSize)]`.
 * Output column `outIdx` holds segment `outIdx mod numSegments` of window
 * `floor(outIdx / numSegments)`.
 */
export class SegmentOpProgram implements GPGPUProgram {
  variableNames = ['x', 'segmentIds'];
  outputShape: number[];
  userCode: string;

  /**
   * Builds the output shape and the GLSL source for the given reduction
   * configuration.
   *
   * @param segOpInfo Supplies `windowSize`, `batchSize`, `inSize` and
   *     `numSegments`; all four are baked into the shader source as literals.
   * @param segOpType Kind of segment reduction. The type admits only
   *     `'unsortedSegmentSum'`; the value is not read by this constructor.
   */
  constructor(segOpInfo: SegOpInfo, segOpType: 'unsortedSegmentSum') {
    const {windowSize, batchSize, inSize, numSegments} = segOpInfo;

    const outSize = numSegments * Math.ceil(inSize / windowSize);
    this.outputShape = [batchSize, outSize];

    // Neutral element of the sum; also used to pad unused vec4 lanes.
    const initializationValue = '0.0';

    // The window is processed four elements at a time with vec4 dot products;
    // the remaining 0-3 elements are handled by a dedicated tail branch.
    const windowSizeNearestVec4 = Math.floor(windowSize / 4) * 4;
    const windowSizeVec4Remainder = windowSize % 4;

    // `segFilter` is 1.0 in lanes that belong to the current segment and 0.0
    // elsewhere, so the dot product adds only the matching values.
    const updateSnippet = `
        sumValue += dot(values, segFilter);
    `;

    // Only when inSize is not a multiple of windowSize can the last window
    // read past the end of the input, so the guards are emitted only then.
    const lastWindowIsPartial = inSize % windowSize > 0;
    const outOfBoundsGuard = (fallback: string): string =>
        lastWindowIsPartial ? `
        if (inIdx < 0 || inIdx >= ${inSize}) {
          return ${fallback};
        }
      ` : '';

    // Out-of-range values contribute the neutral element.
    const checkValueOutOfBounds = outOfBoundsGuard('initializationValue');
    // Out-of-range ids become -1.0, which never equals a segment index taken
    // from [0, numSegments).
    const checkSegmentIdOutOfBounds = outOfBoundsGuard('-1.0');

    // The tail conditions below interpolate JavaScript booleans, so the
    // shader source contains literal `if (true)` / `else if (false)` tests
    // and at most one tail branch is ever taken.
    this.userCode = `
      const float initializationValue = ${initializationValue};

      float getValue(int batch, int inIdx) {
        ${checkValueOutOfBounds}
        return getX(batch, inIdx);
      }

      float getSegmentIdAtIndex(int inIdx) {
        ${checkSegmentIdOutOfBounds}
        return getSegmentIds(inIdx);
      }

      void main() {
        ivec2 coords = getOutputCoords();
        int batch = coords[0];
        int outIdx = coords[1];
        int inOffset = int(floor(float(outIdx) / float(
          ${numSegments})) * float(${windowSize}));
        int currentSeg = int(mod(float(outIdx), float(${numSegments})));

        float sumValue = ${initializationValue};

        for (int i = 0; i < ${windowSizeNearestVec4}; i += 4) {
          int inIdx = inOffset + i;
          vec4 values = vec4(
            getValue(batch, inIdx),
            getValue(batch, inIdx + 1),
            getValue(batch, inIdx + 2),
            getValue(batch, inIdx + 3)
          );

          vec4 segFilter = vec4(
            int(getSegmentIdAtIndex(inIdx)) == currentSeg ? 1 : 0,
            int(getSegmentIdAtIndex(inIdx + 1)) == currentSeg ? 1 : 0,
            int(getSegmentIdAtIndex(inIdx + 2)) == currentSeg ? 1 : 0,
            int(getSegmentIdAtIndex(inIdx + 3)) == currentSeg ? 1 : 0
          );

          ${updateSnippet}
        }

        int inIdx = inOffset + ${windowSizeNearestVec4};
        if (${windowSizeVec4Remainder === 1}) {
          vec4 values = vec4(
            getValue(batch, inIdx),
            initializationValue,
            initializationValue,
            initializationValue
          );

          vec4 segFilter = vec4(
            int(getSegmentIdAtIndex(inIdx)) == currentSeg ? 1 : 0,
            0,
            0,
            0
          );

          ${updateSnippet}
        } else if (${windowSizeVec4Remainder === 2}) {
          vec4 values = vec4(
            getValue(batch, inIdx),
            getValue(batch, inIdx + 1),
            initializationValue,
            initializationValue
          );

          vec4 segFilter = vec4(
            int(getSegmentIdAtIndex(inIdx)) == currentSeg ? 1 : 0,
            int(getSegmentIdAtIndex(inIdx + 1)) == currentSeg ? 1 : 0,
            0,
            0
          );

          ${updateSnippet}
        } else if (${windowSizeVec4Remainder === 3}) {
          vec4 values = vec4(
            getValue(batch, inIdx),
            getValue(batch, inIdx + 1),
            getValue(batch, inIdx + 2),
            initializationValue
          );

          vec4 segFilter = vec4(
            int(getSegmentIdAtIndex(inIdx)) == currentSeg ? 1 : 0,
            int(getSegmentIdAtIndex(inIdx + 1)) == currentSeg ? 1 : 0,
            int(getSegmentIdAtIndex(inIdx + 2)) == currentSeg ? 1 : 0,
            0
          );

          ${updateSnippet}
        }
        setOutput(sumValue);
      }
    `;
  }
}