@tensorflow/tfjs-core
Version:
Hardware-accelerated JavaScript library for machine intelligence
137 lines (115 loc) • 3.63 kB
text/typescript
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import {getChannels} from '../packing_util';
import {GPGPUProgram} from './gpgpu_math';
import {getCoordsDataType} from './shader_compiler';
/**
 * Shader program that reads the single input `A` and emits one vec4 per
 * output coordinate.
 *
 * For a rank-0 output the vec4 holds `getA()` followed by three zeros. For
 * higher ranks the vec4 holds up to four samples of `A` taken at the
 * (r, c), (r, c + 1), (r + 1, c) and (r + 1, c + 1) positions of the two
 * innermost dimensions; coordinates past the output shape produce zeros.
 */
export class PackProgram implements GPGPUProgram {
  variableNames = ['A'];
  outputShape: number[];
  userCode: string;
  packedInputs = false;
  packedOutput = true;

  /**
   * @param outputShape Shape of the output tensor; its length selects which
   *     shader variant is generated.
   */
  constructor(outputShape: number[]) {
    // TODO(https://github.com/tensorflow/tfjs/issues/893):
    // Only input / output 3D tensors.
    this.outputShape = outputShape;
    const rank = outputShape.length;

    if (rank !== 0) {
      const channels = getChannels('rc', rank);
      const coordsType = getCoordsDataType(rank);
      const outOfBounds = getOutOfBoundsCondition(rank, outputShape, channels);
      // The last two extents are the column and row counts of the innermost
      // 2D slice, in that order.
      const setup = getSetup(
          rank, outputShape[rank - 1], outputShape[rank - 2], channels);
      const samples = getOutput(outputShape, channels);

      this.userCode = [
        '',
        'void main() {',
        `${coordsType} rc = getOutputCoords();`,
        `if(${outOfBounds}) {`,
        'setOutput(vec4(0));',
        '} else {',
        setup,
        `setOutput(vec4(${samples}));`,
        '}',
        '}',
        '',
      ].join('\n');
    } else {
      // Scalar: a single value in the first channel, the rest zero-filled.
      this.userCode = [
        '',
        'void main() {',
        'setOutput(vec4(getA(), 0., 0., 0.));',
        '}',
        '',
      ].join('\n');
    }
  }
}
/**
 * Builds the four argument lists used to sample the input at the 2x2
 * neighborhood of the current position.
 *
 * The results are ordered (r, c), (r, cp1), (rp1, c), (rp1, cp1). For ranks
 * above 2, each entry is prefixed with the outer-dimension channel names
 * taken from `dims`.
 *
 * @param rank Number of dimensions of the tensor being sampled.
 * @param dims Channel names, one per dimension.
 * @returns Four comma-separated coordinate strings.
 */
function getSourceCoordsArr(rank: number, dims: string[]): string[] {
  // The outer-dimension prefix is identical for all four samples, so it is
  // assembled once. Each step prepends the next channel, walking `dims`
  // backwards from index `dims.length - 3`.
  let outerPrefix = '';
  for (let d = 2; d < rank; d++) {
    outerPrefix = `${dims[dims.length - 1 - d]},${outerPrefix}`;
  }

  const result: string[] = [];
  for (const rowName of ['r', 'rp1']) {
    for (const colName of ['c', 'cp1']) {
      result.push(`${outerPrefix}${rowName}, ${colName}`);
    }
  }
  return result;
}
/**
 * Builds the shader boolean expression that is true when the current output
 * coordinate lies outside `shape`.
 *
 * For rank 1 the scalar coordinate `rc` is compared against the single
 * extent. For higher ranks only the two innermost dimensions are checked,
 * joined with `||`.
 *
 * @param rank Number of dimensions of the output.
 * @param shape Output shape; `shape[i]` is the extent of dimension `i`.
 * @param dims Channel names, one per dimension.
 * @returns The condition as shader source text.
 */
function getOutOfBoundsCondition(
    rank: number, shape: number[], dims: string[]): string {
  if (rank === 1) {
    // An index equal to the extent is already past the last element, so the
    // comparison must be `>=` (matching the higher-rank branch). With `>` the
    // coordinate `rc == shape[0]` would fall through to an unguarded read.
    return `rc >= ${shape[0]}`;
  }

  return [rank - 2, rank - 1]
      .map((i) => `${dims[i]} >= ${shape[i]}`)
      .join('||');
}
/**
 * Emits the shader declarations for the row/column helpers (`r`, `c`, `rp1`,
 * `cp1`) and the edge flags (`cEdge`, `rEdge`) consumed by the output
 * expression.
 *
 * @param rank Number of dimensions of the output; rank 1 needs no setup.
 * @param cols Extent that `cp1` is compared against.
 * @param rows Extent that `rp1` is compared against.
 * @param dims Channel names; the last two supply `r` and `c`.
 * @returns Shader source text, or an empty string for rank 1.
 */
function getSetup(
    rank: number, cols: number, rows: number, dims: string[]): string {
  if (rank === 1) {
    return '';
  }

  const [rowChannel, colChannel] = dims.slice(-2);
  const declarations = [
    `int r = ${rowChannel};`,
    `int c = ${colChannel};`,
    'int rp1 = r + 1;',
    'int cp1 = c + 1;',
    `bool cEdge = cp1 >= ${cols};`,
    `bool rEdge = rp1 >= ${rows};`,
  ];
  // The snippet is wrapped in a leading and trailing newline.
  return `\n${declarations.join('\n')}\n`;
}
/**
 * Builds the four comma-separated vec4 components written for one output
 * coordinate.
 *
 * For rank 1 the components are `getA(rc)`, a guarded `getA(rc + 1)`, and two
 * zeros. Otherwise they are the four 2x2-neighborhood samples, with the
 * `cEdge` / `rEdge` flags substituting `0.` where the neighbor is past the
 * edge.
 *
 * @param shape Output shape; its length is the rank.
 * @param dims Channel names, one per dimension.
 * @returns The component list as shader source text.
 */
function getOutput(shape: number[], dims: string[]): string {
  const rank = shape.length;
  const [topLeft, topRight, bottomLeft, bottomRight] =
      getSourceCoordsArr(rank, dims);

  const components = rank === 1 ?
      [
        'getA(rc)',
        `rc + 1 >= ${shape[0]} ? 0. : getA(rc + 1)`,
        '0, 0',
      ] :
      [
        `getA(${topLeft})`,
        `cEdge ? 0. : getA(${topRight})`,
        `rEdge ? 0. : getA(${bottomLeft})`,
        `rEdge || cEdge ? 0. : getA(${bottomRight})`,
      ];
  return components.join(',\n');
}