// @tensorflow/tfjs-core
// Version:
// Hardware-accelerated JavaScript library for machine intelligence
// 146 lines (134 loc) • 5.76 kB
// text/typescript
/**
* @license
* Copyright 2019 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import {assert} from '../../util';
import {getChannels} from '../packing_util';
import {GPGPUProgram} from './gpgpu_math';
import {getCoordsDataType} from './shader_compiler';
/**
 * Packed-texture WebGL program computing arg-max / arg-min along the last
 * dimension of `shape`.
 *
 * Each shader invocation produces one packed output texel that covers a 2x2
 * block of the two innermost output dimensions, with lanes laid out as
 * R = (row, col), G = (row, col + 1), B = (row + 1, col),
 * A = (row + 1, col + 1). For every lane it scans `windowSize` consecutive
 * positions of the reduced dimension, starting at
 * (output position along that dimension) * windowSize, and writes the index
 * of the best value found (as a float).
 *
 * Output shape: `shape` without its last dimension, followed by
 * `ceil(inSize / windowSize)` only when that is greater than 1. When the
 * whole last dimension fits in a single window, the reduced dimension is
 * dropped from the output.
 *
 * Multi-pass use: with `firstPass === false` a second input `bestIndicesA`
 * is bound; candidate indices are read from it, while the values being
 * compared are always read from `A` at those indices.
 */
export class ArgMinMaxPackedProgram implements GPGPUProgram {
// Shader inputs; 'bestIndicesA' is appended by the constructor on
// non-first passes.
variableNames = ['A'];
outputShape: number[];
userCode: string;
// Marks this program as operating on packed textures (interpreted by the
// GPGPU runtime via the GPGPUProgram contract).
usesPackedTextures = true;
/**
 * @param shape Shape whose last dimension is reduced. Must have rank > 2
 *     (asserted). NOTE(review): on non-first passes the caller decides
 *     whether this is A's shape or bestIndicesA's shape — confirm at caller.
 * @param windowSize Number of consecutive positions each output lane
 *     reduces; also emitted verbatim as the shader loop bound and as the
 *     multiplier from output position to window start.
 * @param op 'max' selects GLSL `greaterThan`, 'min' selects `lessThan`.
 * @param firstPass When false, adds the 'bestIndicesA' input and reads
 *     candidate indices from it instead of using raw window positions.
 */
constructor(
shape: number[], windowSize: number, op: 'max'|'min',
firstPass: boolean) {
// The shader indexes coords[rank - 2] of the output; when the reduced dim is
// dropped the output rank is shape.length - 1, so the input needs rank >= 3.
assert(
shape.length > 2,
() => `Packed arg${
op.charAt(0).toUpperCase() +
op.slice(1)} supports only inputs with rank above 2.`);
// Number of windows the reduced (last) dimension is split into.
const inSize = shape[shape.length - 1];
const outSize = Math.ceil(inSize / windowSize);
// Drop the reduced dim; re-add it (resized to outSize) only if more than one
// window remains.
this.outputShape = shape.slice(0, -1);
if (outSize > 1) {
this.outputShape.push(outSize);
}
if (!firstPass) {
this.variableNames.push('bestIndicesA');
}
const outShape = this.outputShape;
const rank = outShape.length;
const dtype = getCoordsDataType(rank);
const coords = getChannels('coords', rank);
// GLSL that derives, for each of the four packed lanes, a source location.
// All but its last component are used as-is; the last component is the
// output position along the reduced dim, later scaled by windowSize.
let sourceLocSetup;
let sourceRank;
if (outSize === 1) {
// Reduced dim was dropped from the output: append a 0 coordinate so the
// source locations regain the input's rank. The ++/-- sequence steps coords
// through (r, c) -> (r, c+1) -> (r+1, c+1) -> (r+1, c) and finally restores
// coords to (r, c).
sourceRank = rank + 1;
const sourceLocDType = getCoordsDataType(sourceRank);
sourceLocSetup = `
${sourceLocDType} sourceLocR = ${sourceLocDType}(${coords.join()}, 0);
++${coords[rank - 1]};
${sourceLocDType} sourceLocG = ${sourceLocDType}(${coords.join()}, 0);
++${coords[rank - 2]};
${sourceLocDType} sourceLocA = ${sourceLocDType}(${coords.join()}, 0);
--${coords[rank - 1]};
${sourceLocDType} sourceLocB = ${sourceLocDType}(${coords.join()}, 0);
--${coords[rank - 2]};`;
} else {
// Reduced dim is still present in the output (as the window index), so the
// output coords already have the source rank; same lane stepping as above.
sourceRank = rank;
sourceLocSetup = `
${dtype} sourceLocR = coords;
++${coords[rank - 1]};
${dtype} sourceLocG = coords;
++${coords[rank - 2]};
${dtype} sourceLocA = coords;
--${coords[rank - 1]};
${dtype} sourceLocB = coords;
--${coords[rank - 2]};`;
}
// Component names for up to 6 dimensions. NOTE(review): presumably
// getCoordsDataType rejects ranks above 6 — confirm.
const channels = ['x', 'y', 'z', 'w', 'u', 'v'].slice(0, sourceRank);
const inChannel = '.' + channels[sourceRank - 1]; // e.g. ".z" for rank 3.
// Parameter list ("int x, int y, ...") for the generated GLSL helpers.
const intChannels = channels.map(x => 'int ' + x);
// Full per-lane source coordinates: leading components come from the lane's
// sourceLoc, the last one from the current window index in inIdx.
const srcRCoords =
getChannels('sourceLocR', sourceRank - 1).concat('inIdx.r');
const srcGCoords =
getChannels('sourceLocG', sourceRank - 1).concat('inIdx.g');
const srcBCoords =
getChannels('sourceLocB', sourceRank - 1).concat('inIdx.b');
const srcACoords =
getChannels('sourceLocA', sourceRank - 1).concat('inIdx.a');
// Strict comparison: on ties the earlier candidate is kept.
const compOp = (op === 'max') ? 'greaterThan' : 'lessThan';
// Non-first passes: replace the raw window positions in inIdx with the
// indices recorded in bestIndicesA (read as floats, hence the round()).
const fetchCandidateIdx = firstPass ? '' : `
inIdx = round(vec4(getBestIndicesAChannel(${srcRCoords.join()}),
getBestIndicesAChannel(${srcGCoords.join()}),
getBestIndicesAChannel(${srcBCoords.join()}),
getBestIndicesAChannel(${srcACoords.join()})));`;
// Read the four lane values from A; lanes beyond the last output column/row
// (hasNextCol / hasNextRow false) get 0 instead of reading A.
const fetchValue = `vec4(
getAChannel(${srcRCoords.join()}),
hasNextCol ? getAChannel(${srcGCoords.join()}) : 0.,
hasNextRow ? getAChannel(${srcBCoords.join()}) : 0.,
hasNextRow && hasNextCol ? getAChannel(${srcACoords.join()}) : 0.)`;
// Helper (non-first passes only) that picks one channel of the packed
// bestIndicesA texel via getChannel, keyed by the two innermost coordinates.
const getBestIndicesAChannelSnippet = firstPass ? '' : `
float getBestIndicesAChannel(${intChannels.join()}) {
return getChannel(getBestIndicesA(${channels.join()}),
vec2(${channels.slice(-2).join()}));
}`;
// Shader body. Candidates that are NaN are masked out of `replace`, so the
// loop never selects them. NOTE(review): bestValue/bestIndex are seeded from
// A at the raw window positions (srcIdx) before fetchCandidateIdx runs, so on
// non-first passes the seed index is a window position rather than an index
// read from bestIndicesA — confirm this is intended.
this.userCode = `
float getAChannel(${intChannels.join()}) {
return getChannel(getA(${channels.join()}),
vec2(${channels.slice(-2).join()}));
}
${getBestIndicesAChannelSnippet}
void main() {
${dtype} coords = getOutputCoords();
bool hasNextCol = ${coords[rank - 1]} < ${outShape[rank - 1] - 1};
bool hasNextRow = ${coords[rank - 2]} < ${outShape[rank - 2] - 1};
${sourceLocSetup}
ivec4 srcIdx = ivec4(sourceLocR${inChannel}, sourceLocG${inChannel},
sourceLocB${inChannel}, sourceLocA${inChannel}) * ${windowSize};
ivec4 inIdx = srcIdx;
vec4 bestIndex = vec4(inIdx);
vec4 bestValue = ${fetchValue};
for (int i = 0; i < ${windowSize}; i++) {
inIdx = srcIdx;
${fetchCandidateIdx}
vec4 candidate = ${fetchValue};
bvec4 nan = isnan(candidate);
bvec4 replace = bvec4(
vec4(${compOp}(candidate, bestValue)) * (vec4(1.0) - vec4(nan)));
bestValue = vec4(replace.x ? candidate.x : bestValue.x,
replace.y ? candidate.y : bestValue.y,
replace.z ? candidate.z : bestValue.z,
replace.w ? candidate.w : bestValue.w);
bestIndex = mix(bestIndex, vec4(inIdx), vec4(replace));
srcIdx++;
}
setOutput(bestIndex);
}
`;
}
}