UNPKG

@tensorflow/tfjs-core

Version:

Hardware-accelerated JavaScript library for machine intelligence

87 lines (77 loc) 2.87 kB
/**
 * @license
 * Copyright 2019 Google Inc. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */

import {getChannels} from '../packing_util';
import {GPGPUProgram} from './gpgpu_math';
import {getCoordsDataType} from './shader_compiler';

/**
 * Packed-texture GPGPU program that pads the input `x`.
 *
 * Each invocation produces one packed output texel. For a rank-1 input the
 * texel has two components. For higher ranks it has four components, forming
 * a 2x2 patch over the two innermost dimensions.
 *
 * A component whose coordinate lies outside the copied input region receives
 * `constantValue`. A component whose coordinate runs past the output shape
 * is skipped and keeps its initial value of 0.
 */
export class PadPackedProgram implements GPGPUProgram {
  variableNames = ['x'];
  usesPackedTextures = true;
  outputShape: number[];
  userCode: string;

  /**
   * @param xShape Shape of the input tensor.
   * @param paddings One `[before, after]` pair per input dimension.
   * @param constantValue Fill value for padded positions. It is spliced into
   *     the generated shader source as text.
   *     NOTE(review): non-finite values stringify to `NaN` / `Infinity` /
   *     `-Infinity`, which are not GLSL float literals. Unless the shader
   *     prelude defines those names, the program will fail to compile.
   *     TODO confirm.
   */
  constructor(
      xShape: number[], paddings: Array<[number, number]>,
      constantValue: number) {
    this.outputShape =
        paddings.map(([before, after], axis) => before + xShape[axis] + after);
    const outShape = this.outputShape;

    const numDims = xShape.length;
    const isVector = numDims === 1;
    const coordType = getCoordsDataType(numDims);

    // Wraps GLSL fragments as " a b c ": single spaces between and around.
    const spaced = (parts: string[]) => ` ${parts.join(' ')} `;

    // The copied input region, in output coordinates, is
    // [lowerBounds, upperBounds).
    const lowerBounds = paddings.map(([before]) => before).join(',');
    const upperBounds =
        paddings.map(([before], axis) => before + xShape[axis]).join(',');

    const rcChannels = getChannels('rc', numDims);
    const srcChannels = getChannels('source', numDims);
    const lastRc = rcChannels[numDims - 1];
    const lastInBounds = `${lastRc} < ${outShape[numDims - 1]}`;

    let channelSelector: string;
    let outsideInput: string;
    if (isVector) {
      channelSelector = 'source';
      outsideInput = 'rc < start || rc >= end';
    } else {
      // NOTE(review): lessThan/any require a vector coordType. This presumably
      // assumes rank <= 4. TODO confirm what getCoordsDataType returns for
      // ranks 5 and 6.
      channelSelector = `vec2(${srcChannels.slice(-2).join(',')})`;
      outsideInput =
          'any(lessThan(rc, start)) || any(greaterThanEqual(rc, end))';
    }

    // Per-component preamble that moves `rc` to the component's coordinate
    // and opens a bounds guard:
    // - component 0 is `outputLoc`;
    // - component 1 steps the innermost dimension;
    // - for rank >= 2, component 2 steps the second-innermost dimension;
    // - component 3 steps both.
    const preambles = [
      `${coordType} rc = outputLoc;`,
      `${lastRc} += 1; if(${lastInBounds}) { `,
    ];
    if (!isVector) {
      const rowRc = rcChannels[numDims - 2];
      preambles.push(
          `} rc = outputLoc; ${rowRc} += 1; if(${rowRc} < ${outShape[numDims - 2]}) {`,
          ` ${lastRc} += 1; if(${lastInBounds}) {`);
    }

    const srcArgs = srcChannels.join(',');
    const components = preambles.map((preamble, i) => spaced([
      preamble,
      `if (${outsideInput}) {`,
      `result[${i}] = float(${constantValue});`,
      '} else {',
      `${coordType} source = rc - start;`,
      `result[${i}] = getChannel(getX(${srcArgs}), ${channelSelector});`,
      '}',
    ]));
    // Close the guards that the preambles left open: one for rank 1,
    // two otherwise.
    const mainLoop = components.join('') + (isVector ? '} ' : '}}');

    this.userCode = spaced([
      `const ${coordType} start = ${coordType}(${lowerBounds});`,
      `const ${coordType} end = ${coordType}(${upperBounds});`,
      'void main() {',
      `${coordType} outputLoc = getOutputCoords();`,
      'vec4 result = vec4(0.);',
      mainLoop,
      'setOutput(result);',
      '}',
    ]);
  }
}