UNPKG

@tensorflow/tfjs-core

Version:

Hardware-accelerated JavaScript library for machine intelligence

137 lines (115 loc) 3.63 kB
/**
 * @license
 * Copyright 2018 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */

import {getChannels} from '../packing_util';

import {GPGPUProgram} from './gpgpu_math';
import {getCoordsDataType} from './shader_compiler';

/**
 * Generates the shader source that packs input `A` into a vec4 output.
 *
 * For rank >= 2 each output vec4 holds the values at (r, c), (r, c + 1),
 * (r + 1, c) and (r + 1, c + 1) of the two innermost dimensions; neighbours
 * that fall past the last row / column are written as 0. For rank 1 the vec4
 * holds `A[rc]` and `A[rc + 1]`, and for rank 0 only the first channel is
 * populated.
 *
 * Shader text is emitted verbatim; string pieces are concatenated only to
 * keep the TypeScript source lines short.
 */
export class PackProgram implements GPGPUProgram {
  variableNames = ['A'];
  outputShape: number[];
  userCode: string;
  packedInputs = false;
  packedOutput = true;

  /**
   * @param outputShape Shape of the tensor being packed. Its length selects
   *     the rank-specific shader variant.
   */
  constructor(outputShape: number[]) {
    // TODO(https://github.com/tensorflow/tfjs/issues/893):
    // Only input / output 3D tensors.
    this.outputShape = outputShape;
    const rank = outputShape.length;

    if (rank === 0) {
      // Scalar: there are no coordinates to compute or bounds to check.
      this.userCode =
          ` void main() { setOutput(vec4(getA(), 0., 0., 0.)); } `;
    } else {
      const channels = getChannels('rc', rank);
      const dtype = getCoordsDataType(rank);
      const outOfBoundsCondition =
          getOutOfBoundsCondition(rank, outputShape, channels);
      const setup = getSetup(
          rank, outputShape[outputShape.length - 1],
          outputShape[outputShape.length - 2], channels);
      const output = getOutput(outputShape, channels);

      this.userCode = ` void main() { ${dtype} rc = getOutputCoords();` +
          ` if(${outOfBoundsCondition}) { setOutput(vec4(0)); }` +
          ` else { ${setup} setOutput(vec4(${output})); } } `;
    }
  }
}

/**
 * Builds the four `getA(...)` argument lists for a 2x2 block, in the order
 * (r, c), (r, cp1), (rp1, c), (rp1, cp1).
 *
 * @param rank Rank of the output tensor.
 * @param dims Per-dimension coordinate expressions; every dimension other
 *     than the two innermost is prepended unchanged to each argument list.
 * @returns Four comma-separated coordinate strings.
 */
function getSourceCoordsArr(rank: number, dims: string[]): string[] {
  // The outer-dimension prefix is identical for all four entries, so it is
  // assembled once up front.
  let outerPrefix = '';
  for (let d = 2; d < rank; d++) {
    outerPrefix = `${dims[dims.length - 1 - d]},` + outerPrefix;
  }

  const coords: string[] = [];
  for (let row = 0; row <= 1; row++) {
    for (let col = 0; col <= 1; col++) {
      const inner =
          `${row === 0 ? 'r' : 'rp1'}, ${col === 0 ? 'c' : 'cp1'}`;
      coords.push(outerPrefix + inner);
    }
  }
  return coords;
}

/**
 * Builds the shader condition that is true when the current output
 * coordinate lies outside the tensor.
 *
 * @param rank Rank of the output tensor.
 * @param shape Output shape.
 * @param dims Per-dimension coordinate expressions.
 * @returns A boolean shader expression.
 */
function getOutOfBoundsCondition(
    rank: number, shape: number[], dims: string[]): string {
  if (rank === 1) {
    // Valid indices are 0 .. shape[0] - 1, so `rc == shape[0]` is already out
    // of bounds. Use `>=`, matching the check applied to higher ranks below
    // (the previous `>` let the shader read one element past the end).
    return `rc >= ${shape[0]}`;
  }

  // Only the two innermost dimensions are checked.
  const checks: string[] = [];
  for (let i = rank - 2; i < rank; i++) {
    checks.push(`${dims[i]} >= ${shape[i]}`);
  }
  return checks.join('||');
}

/**
 * Emits the shader declarations for the row / column of the current 2x2
 * block, its +1 neighbours, and the flags marking the last row / column.
 *
 * @param rank Rank of the output tensor; rank 1 needs no setup.
 * @param cols Size of the innermost dimension.
 * @param rows Size of the second innermost dimension.
 * @param dims Per-dimension coordinate expressions; the last two are used.
 * @returns Shader statements, or an empty string for rank 1.
 */
function getSetup(
    rank: number, cols: number, rows: number, dims: string[]): string {
  if (rank === 1) {
    return '';
  }

  const innerDims = dims.slice(-2);

  return ` int r = ${innerDims[0]}; int c = ${innerDims[1]};` +
      ` int rp1 = r + 1; int cp1 = c + 1;` +
      ` bool cEdge = cp1 >= ${cols}; bool rEdge = rp1 >= ${rows}; `;
}

/**
 * Builds the four comma-separated channel expressions passed to `vec4(...)`.
 *
 * @param shape Output shape.
 * @param dims Per-dimension coordinate expressions.
 * @returns Shader expressions for the four channels; positions past the edge
 *     of the tensor evaluate to 0.
 */
function getOutput(shape: number[], dims: string[]): string {
  const rank = shape.length;

  if (rank === 1) {
    // 1-D input: two consecutive values, remaining channels are zero.
    return `getA(rc), rc + 1 >= ${shape[0]} ? 0. : getA(rc + 1), 0, 0`;
  }

  const sourceCoords = getSourceCoordsArr(rank, dims);

  return `getA(${sourceCoords[0]}),` +
      ` cEdge ? 0. : getA(${sourceCoords[1]}),` +
      ` rEdge ? 0. : getA(${sourceCoords[2]}),` +
      ` rEdge || cEdge ? 0. : getA(${sourceCoords[3]})`;
}