UNPKG

@tensorflow/tfjs-core

Version:

Hardware-accelerated JavaScript library for machine intelligence

146 lines (134 loc) 5.76 kB
/**
 * @license
 * Copyright 2019 Google Inc. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */

import {assert} from '../../util';
import {getChannels} from '../packing_util';
import {GPGPUProgram} from './gpgpu_math';
import {getCoordsDataType} from './shader_compiler';

/**
 * Packed-texture GPGPU program that reduces the innermost axis of input `A`
 * to the index of its maximum (`op === 'max'`) or minimum (`op === 'min'`)
 * value.
 *
 * The innermost axis is scanned in windows of `windowSize` consecutive
 * elements. Each output texel covers a 2x2 block of the last two output
 * dimensions — (row, col), (row, col + 1), (row + 1, col), (row + 1, col + 1),
 * stored in that component order — and holds the winning index for each of
 * those positions as a float.
 */
export class ArgMinMaxPackedProgram implements GPGPUProgram {
  variableNames = ['A'];
  outputShape: number[];
  userCode: string;
  usesPackedTextures = true;

  /**
   * @param shape Shape of input `A`; its rank must be greater than 2.
   * @param windowSize Number of consecutive innermost elements reduced into
   *     each output element. The output keeps a trailing axis of length
   *     `ceil(shape[shape.length - 1] / windowSize)` only when that length is
   *     greater than 1; otherwise the reduced axis is dropped entirely.
   * @param op 'max' keeps the candidate that compares `greaterThan` the current
   *     best, 'min' the one that compares `lessThan`. The comparison is strict,
   *     so on ties the earlier-scanned candidate is kept.
   * @param firstPass When false, a second input `bestIndicesA` is declared and
   *     its (rounded) values are used as the candidate indices into `A` instead
   *     of the raw window positions. Presumably `bestIndicesA` is the output of
   *     an earlier pass of this program — TODO confirm with the caller.
   *
   * Fails the `assert` when `shape` has rank <= 2.
   *
   * NOTE(review): candidates for which `isnan()` is true never replace the
   * current best, but the initial best (read at the window start) is not
   * NaN-filtered, so a NaN there may be kept — confirm this matches the other
   * backends.
   * NOTE(review): the scan reads positions `srcIdx .. srcIdx + windowSize - 1`
   * without a bounds check against the innermost size; this assumes the
   * caller picks a `windowSize` that divides it (or otherwise guarantees
   * in-range reads) — confirm.
   */
  constructor(
      shape: number[], windowSize: number, op: 'max'|'min',
      firstPass: boolean) {
    assert(
        shape.length > 2,
        () => `Packed arg${
            op.charAt(0).toUpperCase() +
            op.slice(1)} supports only inputs with rank above 2.`);

    const inSize = shape[shape.length - 1];
    const outSize = Math.ceil(inSize / windowSize);
    this.outputShape = shape.slice(0, -1);
    if (outSize > 1) {
      this.outputShape.push(outSize);
    }
    if (!firstPass) {
      this.variableNames.push('bestIndicesA');
    }

    const outShape = this.outputShape;
    const rank = outShape.length;
    const dtype = getCoordsDataType(rank);
    const coords = getChannels('coords', rank);
    const lastCoord = coords[rank - 1];
    const secondLastCoord = coords[rank - 2];

    // When a single window spans the whole row, the reduced axis is absent
    // from the output, so each source location gets one extra coordinate
    // (initialised to 0) for the innermost axis of `A`.
    const sourceRank = outSize === 1 ? rank + 1 : rank;
    const sourceLocDType =
        outSize === 1 ? getCoordsDataType(sourceRank) : dtype;
    const sourceLocInit =
        outSize === 1 ? `${sourceLocDType}(${coords.join()}, 0)` : 'coords';

    // Walk the 2x2 block (R: r,c -> G: r,c+1 -> A: r+1,c+1 -> B: r+1,c) by
    // bumping the GLSL `coords` in place, then restore it to (r, c).
    const sourceLocSetup = `
        ${sourceLocDType} sourceLocR = ${sourceLocInit};
        ++${lastCoord};
        ${sourceLocDType} sourceLocG = ${sourceLocInit};
        ++${secondLastCoord};
        ${sourceLocDType} sourceLocA = ${sourceLocInit};
        --${lastCoord};
        ${sourceLocDType} sourceLocB = ${sourceLocInit};
        --${secondLastCoord};`;

    const channels = ['x', 'y', 'z', 'w', 'u', 'v'].slice(0, sourceRank);
    // Swizzle selecting the innermost (reduced) source coordinate, e.g. ".z"
    // when sourceRank is 3.
    const inChannel = '.' + channels[sourceRank - 1];
    const intChannels = channels.map(channel => 'int ' + channel);
    // The last two coordinates pick the component within a packed texel.
    const texelCoords = channels.slice(-2).join();

    // Argument list for getAChannel / getBestIndicesAChannel: the leading
    // coordinates of one source location followed by the matching component
    // of `inIdx` as the innermost index.
    const srcCoords = (sourceLoc: string, component: string): string =>
        getChannels(sourceLoc, sourceRank - 1)
            .concat(`inIdx.${component}`)
            .join();
    const srcRCoords = srcCoords('sourceLocR', 'r');
    const srcGCoords = srcCoords('sourceLocG', 'g');
    const srcBCoords = srcCoords('sourceLocB', 'b');
    const srcACoords = srcCoords('sourceLocA', 'a');

    const compOp = (op === 'max') ? 'greaterThan' : 'lessThan';

    // On non-first passes the candidate indices are looked up in
    // bestIndicesA instead of being the raw window positions.
    const fetchCandidateIdx = firstPass ? '' : `
          inIdx = round(vec4(getBestIndicesAChannel(${srcRCoords}),
                             getBestIndicesAChannel(${srcGCoords}),
                             getBestIndicesAChannel(${srcBCoords}),
                             getBestIndicesAChannel(${srcACoords})));`;

    // Positions past the right/bottom edge of the output are filled with 0.
    // instead of being read from A.
    const fetchValue = `vec4(
            getAChannel(${srcRCoords}),
            hasNextCol ? getAChannel(${srcGCoords}) : 0.,
            hasNextRow ? getAChannel(${srcBCoords}) : 0.,
            hasNextRow && hasNextCol ? getAChannel(${srcACoords}) : 0.)`;

    const getBestIndicesAChannelSnippet = firstPass ? '' : `
      float getBestIndicesAChannel(${intChannels.join()}) {
        return getChannel(getBestIndicesA(${channels.join()}),
                                          vec2(${texelCoords}));
      }`;

    this.userCode = `
      float getAChannel(${intChannels.join()}) {
        return getChannel(getA(${channels.join()}),
                               vec2(${texelCoords}));
      }
      ${getBestIndicesAChannelSnippet}
      void main() {
        ${dtype} coords = getOutputCoords();
        bool hasNextCol = ${lastCoord} < ${outShape[rank - 1] - 1};
        bool hasNextRow = ${secondLastCoord} < ${outShape[rank - 2] - 1};
        ${sourceLocSetup}
        ivec4 srcIdx = ivec4(sourceLocR${inChannel}, sourceLocG${inChannel},
          sourceLocB${inChannel}, sourceLocA${inChannel}) * ${windowSize};
        ivec4 inIdx = srcIdx;
        vec4 bestIndex = vec4(inIdx);
        vec4 bestValue = ${fetchValue};

        for (int i = 0; i < ${windowSize}; i++) {
          inIdx = srcIdx;
          ${fetchCandidateIdx}
          vec4 candidate = ${fetchValue};
          bvec4 nan = isnan(candidate);
          bvec4 replace = bvec4(
            vec4(${compOp}(candidate, bestValue)) * (vec4(1.0) - vec4(nan)));

          bestValue = vec4(replace.x  ? candidate.x : bestValue.x,
                           replace.y  ? candidate.y : bestValue.y,
                           replace.z  ? candidate.z : bestValue.z,
                           replace.w  ? candidate.w : bestValue.w);
          bestIndex = mix(bestIndex, vec4(inIdx), vec4(replace));
          srcIdx++;
        }
        setOutput(bestIndex);
      }
    `;
  }
}