// Package: @tensorflow/tfjs-core
// Description: Hardware-accelerated JavaScript library for machine intelligence
// (Listing metadata from the package viewer: 87 lines (77 loc), 2.87 kB, text/typescript.)
/**
* @license
* Copyright 2019 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import {getChannels} from '../packing_util';
import {GPGPUProgram} from './gpgpu_math';
import {getCoordsDataType} from './shader_compiler';
/**
 * GPGPU program that pads the packed input `x` with a constant value.
 *
 * The generated shader handles one packed output pixel per invocation and
 * fills up to four `result` components:
 * - two neighbouring positions along the last axis for rank-1 inputs;
 * - a 2x2 neighbourhood over the last two axes otherwise.
 *
 * A position outside the copied region `[start, end)` receives
 * `constantValue`. Every other position reads `x` at `rc - start`.
 */
export class PadPackedProgram implements GPGPUProgram {
  variableNames = ['x'];
  usesPackedTextures = true;
  outputShape: number[];
  userCode: string;

  /**
   * @param xShape Shape of the input tensor `x`.
   * @param paddings `[before, after]` pad amounts, one pair per axis.
   * @param constantValue Fill value for padded positions. It is interpolated
   *     into the shader source as `float(<value>)`.
   *     NOTE(review): NaN/Infinity stringify to `NaN`/`Infinity`, which GLSL
   *     does not accept as literals unless the shader prelude defines such
   *     identifiers — confirm callers never pass non-finite values.
   */
  constructor(
      xShape: number[], paddings: Array<[number, number]>,
      constantValue: number) {
    // Each output extent is: before + input extent + after.
    this.outputShape =
        paddings.map(([before, after], axis) => before + xShape[axis] + after);

    const rank = xShape.length;
    const dtype = getCoordsDataType(rank);
    const isVector = rank === 1;

    // Output-space bounds of the region copied from `x`. They are
    // comma-joined because they become GLSL constructor arguments.
    const startList = paddings.map(([before]) => before).join(',');
    const endList =
        paddings.map(([before], axis) => before + xShape[axis]).join(',');

    const rcChannels = getChannels('rc', rank);
    const sourceChannels = getChannels('source', rank);
    const innermost = rcChannels[rank - 1];
    const innermostInBounds = `${innermost} < ${this.outputShape[rank - 1]}`;
    const innerDims =
        isVector ? 'source' : `vec2(${sourceChannels.slice(-2).join()})`;

    // Setup fragments that position `rc` before each component is written:
    //   0: the pixel's own output coordinate;
    //   1: one step along the last axis;
    //   2: (rank > 1) back to the origin, then one step along the second-last axis;
    //   3: (rank > 1) then one more step along the last axis.
    // Each `if(` opened by a fragment is closed later. Fragment 2's leading
    // `}` closes fragment 1's block; the braces appended after the loop close
    // the rest.
    const rowSetup = isVector ? '' : [
      '}',
      'rc = outputLoc;',
      `${rcChannels[rank - 2]} += 1;`,
      `if(${rcChannels[rank - 2]} < ${this.outputShape[rank - 2]}) {`,
    ].join('\n');
    const rowInnerSetup = isVector ? '' : [
      ` ${innermost} += 1;`,
      `if(${innermostInBounds}) {`,
    ].join('\n');
    const componentSetup = [
      `${dtype} rc = outputLoc;`,
      [`${innermost} += 1;`, `if(${innermostInBounds}) {`, ''].join('\n'),
      rowSetup,
      rowInnerSetup,
    ];

    // GLSL condition that holds when `rc` lies outside [start, end).
    const paddingArea = isVector ?
        'rc < start || rc >= end' :
        'any(lessThan(rc, start)) || any(greaterThanEqual(rc, end))';

    const componentCount = isVector ? 2 : 4;
    const sourceArgs = sourceChannels.join();
    const componentBodies: string[] = [];
    for (let component = 0; component < componentCount; component++) {
      componentBodies.push([
        '',
        componentSetup[component],
        `if (${paddingArea}) {`,
        `result[${component}] = float(${constantValue});`,
        '} else {',
        `${dtype} source = rc - start;`,
        `result[${component}] = getChannel(getX(${sourceArgs}), ${innerDims});`,
        '}',
        '',
      ].join('\n'));
    }
    // Close the `if(` blocks that the setup fragments left open.
    const mainLoop = componentBodies.join('') + (isVector ? '} ' : '}}');

    this.userCode = [
      '',
      `const ${dtype} start = ${dtype}(${startList});`,
      `const ${dtype} end = ${dtype}(${endList});`,
      'void main() {',
      `${dtype} outputLoc = getOutputCoords();`,
      'vec4 result = vec4(0.);',
      mainLoop,
      'setOutput(result);',
      '}',
      '',
    ].join('\n');
  }
}