/**
 * @tensorflow/tfjs-core
 * Version: (not specified in source)
 * Hardware-accelerated JavaScript library for machine intelligence
 * 414 lines • 30.8 kB
 * JavaScript
 */
"use strict";
Object.defineProperty(exports, "__esModule", { value: true });
var broadcast_util = require("../../ops/broadcast_util");
var util = require("../../util");
function makeShader(inputsInfo, outputShape, userCode, broadcast) {
var inputPrefixSnippet = inputsInfo.map(function (x) {
var size = util.sizeFromShape(x.shapeInfo.logicalShape);
if (x.shapeInfo.isUniform) {
return "uniform float " + x.name + (size > 1 ? "[" + size + "]" : '') + ";";
}
return "uniform sampler2D " + x.name + ";";
});
inputPrefixSnippet = inputPrefixSnippet.join('\n');
var inputSamplingSnippet = inputsInfo.map(function (x) { return getInputSamplingSnippet(x, outputShape, broadcast); })
.join('\n');
var outTexShape = outputShape.texShape;
var outputSamplingSnippet = getOutputSamplingSnippet(outputShape.logicalShape, outTexShape);
var source = [
SHADER_PREFIX, FLOAT_TEXTURE_SAMPLE_SNIPPET,
FLOAT_TEXTURE_SETOUTPUT_SNIPPET, inputPrefixSnippet, outputSamplingSnippet,
inputSamplingSnippet, userCode
].join('\n');
return source;
}
exports.makeShader = makeShader;
/**
 * Returns the GLSL sampler (`float get<Name>(...)`) matching the rank of the
 * input's logical shape.
 *
 * @param inInfo input descriptor with `shapeInfo.logicalShape`.
 * @returns GLSL source text.
 * @throws Error for ranks above 6.
 */
function getSamplerFromInInfo(inInfo) {
    var rank = inInfo.shapeInfo.logicalShape.length;
    // Entry i emits the sampler for a rank-i input.
    var samplersByRank = [
        getSamplerScalar, getSampler1D, getSampler2D, getSampler3D,
        getSampler4D, getSampler5D, getSampler6D
    ];
    if (rank >= samplersByRank.length) {
        throw new Error(rank + "-D input sampling" +
            " is not yet supported");
    }
    var emitSampler = samplersByRank[rank];
    return emitSampler(inInfo);
}
/**
 * Emits every GLSL accessor for one input: the flat getter, the
 * rank-specific getter, and — when broadcasting is enabled or the input has
 * exactly the output's logical shape — the `...AtOutCoords` getter.
 *
 * @param inInfo input descriptor.
 * @param outShapeInfo output descriptor (`logicalShape`, `texShape`).
 * @param broadcast truthy to always emit the output-coordinates getter.
 * @returns concatenated GLSL source text.
 */
function getInputSamplingSnippet(inInfo, outShapeInfo, broadcast) {
    var flatSampler = getSamplerFlat(inInfo);
    var shapedSampler = getSamplerFromInInfo(inInfo);
    // arraysEqual is only consulted when broadcasting is off.
    var wantsOutCoordsSampler = broadcast ||
        util.arraysEqual(inInfo.shapeInfo.logicalShape, outShapeInfo.logicalShape);
    if (!wantsOutCoordsSampler) {
        return flatSampler + shapedSampler;
    }
    return flatSampler + shapedSampler +
        getSamplerAtOutputCoords(inInfo, outShapeInfo, broadcast);
}
/**
 * Returns the GLSL `getOutputCoords()` definition for an output of the given
 * logical shape, laid out in a texture of shape `outTexShape`.
 *
 * @param outShape logical output shape.
 * @param outTexShape [rows, cols] of the output texture.
 * @returns GLSL source text.
 * @throws Error for ranks above 6.
 */
function getOutputSamplingSnippet(outShape, outTexShape) {
    var rank = outShape.length;
    if (rank === 0) {
        return getOutputScalarCoords();
    }
    // Entry i emits getOutputCoords() for a rank-i output; rank 0 is handled above.
    var emitters = [
        null, getOutput1DCoords, getOutput2DCoords, getOutput3DCoords,
        getOutput4DCoords, getOutput5DCoords, getOutput6DCoords
    ];
    var emit = rank < emitters.length ? emitters[rank] : null;
    if (emit === null) {
        throw new Error(rank + "-D output sampling is not yet supported");
    }
    return emit(outShape, outTexShape);
}
// GLSL helpers UVfrom1D..UVfrom6D map a logical coordinate to a texture UV.
// Each helper:
//  - flattens the coordinate into a row-major index using the supplied strides
//    (UVfrom1D receives the index directly);
//  - splits that index into a (texR, texC) texel position for a texture with
//    texNumC columns;
//  - adds halfCR so the UV addresses the centre of that texel.
// Integer arithmetic is used throughout because dot() only accepts floats.
// These strings are spliced into SHADER_PREFIX below, so they must be
// declared before it.
var SAMPLE_1D_SNIPPET = "\nvec2 UVfrom1D(int texNumR, int texNumC, int index) {\n int texR = index / texNumC;\n int texC = index - texR * texNumC;\n return (vec2(texC, texR) + halfCR) / vec2(texNumC, texNumR);\n}\n";
var SAMPLE_2D_SNIPPET = "\nvec2 UVfrom2D(int texNumR, int texNumC, int numC, int row, int col) {\n int index = row * numC + col;\n int texR = index / texNumC;\n int texC = index - texR * texNumC;\n return (vec2(texC, texR) + halfCR) / vec2(texNumC, texNumR);\n}\n";
var SAMPLE_3D_SNIPPET = "\nvec2 UVfrom3D(int texNumR, int texNumC, int stride0,\n int stride1, int row, int col, int depth) {\n // Explicitly use integer operations as dot() only works on floats.\n int index = row * stride0 + col * stride1 + depth;\n int texR = index / texNumC;\n int texC = index - texR * texNumC;\n return (vec2(texC, texR) + halfCR) / vec2(texNumC, texNumR);\n}\n";
var SAMPLE_4D_SNIPPET = "\nvec2 UVfrom4D(int texNumR, int texNumC, int stride0,\n int stride1, int stride2, int row, int col, int depth,\n int depth2) {\n // Explicitly use integer operations as dot() only works on floats.\n int index = row * stride0 + col * stride1 + depth * stride2 + depth2;\n int texR = index / texNumC;\n int texC = index - texR * texNumC;\n return (vec2(texC, texR) + halfCR) / vec2(texNumC, texNumR);\n}\n";
var SAMPLE_5D_SNIPPET = "\nvec2 UVfrom5D(int texNumR, int texNumC, int stride0,\n int stride1, int stride2, int stride3, int row, int col, int depth,\n int depth2, int depth3) {\n // Explicitly use integer operations as dot() only works on floats.\n int index = row * stride0 + col * stride1 +\n depth * stride2 + depth2 * stride3 + depth3;\n int texR = index / texNumC;\n int texC = index - texR * texNumC;\n return (vec2(texC, texR) + halfCR) / vec2(texNumC, texNumR);\n}\n";
var SAMPLE_6D_SNIPPET = "\nvec2 UVfrom6D(int texNumR, int texNumC, int stride0,\n int stride1, int stride2, int stride3, int stride4,\n int row, int col, int depth, int depth2, int depth3, int depth4) {\n // Explicitly use integer operations as dot() only works on floats.\n int index = row * stride0 + col * stride1 + depth * stride2 + depth2 *\n stride3 + depth3 * stride4 + depth4;\n int texR = index / texNumC;\n int texC = index - texR * texNumC;\n return (vec2(texC, texR) + halfCR) / vec2(texNumC, texNumR);\n}\n";
// sampleTexture(): reads one float per texel; only the red channel is used.
var FLOAT_TEXTURE_SAMPLE_SNIPPET = "\n float sampleTexture(sampler2D textureSampler, vec2 uv) {\n return texture2D(textureSampler, uv).r;\n }\n";
// setOutput(): writes the result into gl_FragColor's red channel and zeroes
// the other three channels.
var FLOAT_TEXTURE_SETOUTPUT_SNIPPET = "\n void setOutput(float val) {\n gl_FragColor = vec4(val, 0, 0, 0);\n }\n";
// Preamble placed first in every generated shader (see makeShader). It
// declares:
//  - highp float/int precision, the `resultUV` varying and the halfCR
//    texel-centre offset;
//  - ivec5/ivec6 structs used by the 5-D/6-D coordinate code;
//  - NaN helpers isNaN/hasNaN/getNaN;
//  - round() as int(floor(x + 0.5)), an integer modulo imod(), and a
//    hash-based random(seed) keyed on resultUV;
//  - the UVfrom1D..UVfrom6D helpers above.
// NOTE(review): hasNaN compares two identical expressions and relies on
// NaN != NaN; an optimizing GLSL compiler could fold them — confirm on target
// drivers. imod uses truncating division; assumes non-negative operands —
// TODO confirm.
var SHADER_PREFIX = "\n precision highp float;\n precision highp int;\n varying vec2 resultUV;\n const vec2 halfCR = vec2(0.5, 0.5);\n\n struct ivec5\n {\n int x;\n int y;\n int z;\n int w;\n int u;\n };\n\n struct ivec6\n {\n int x;\n int y;\n int z;\n int w;\n int u;\n int v;\n };\n\n bool isNaN(float val) {\n return (val < 0.0 || 0.0 < val || val == 0.0) ? false : true;\n }\n\n bool hasNaN(vec4 values) {\n vec4 v1 = values * values;\n vec4 v2 = values * values;\n return any(notEqual(v1, v2));\n }\n\n float getNaN(vec4 values) {\n return dot(vec4(1), values);\n }\n\n int round(float value) {\n return int(floor(value + 0.5));\n }\n\n int imod(int x, int y) {\n return x - y * (x / y);\n }\n\n //Based on the work of Dave Hoskins\n //https://www.shadertoy.com/view/4djSRW\n #define HASHSCALE1 443.8975\n float random(float seed){\n vec2 p = resultUV * seed;\n vec3 p3 = fract(vec3(p.xyx) * HASHSCALE1);\n p3 += dot(p3, p3.yzx + 19.19);\n return fract((p3.x + p3.y) * p3.z);\n }\n\n " + SAMPLE_1D_SNIPPET + "\n " + SAMPLE_2D_SNIPPET + "\n " + SAMPLE_3D_SNIPPET + "\n " + SAMPLE_4D_SNIPPET + "\n " + SAMPLE_5D_SNIPPET + "\n " + SAMPLE_6D_SNIPPET + "\n";
/**
 * Emits GLSL `int getOutputCoords()` for a rank-0 output. There is a single
 * element, so the coordinate is always 0.
 *
 * @returns GLSL source text.
 */
function getOutputScalarCoords() {
    var signature = "\n int getOutputCoords() {";
    var body = "\n return 0;";
    return signature + body + "\n }\n ";
}
/**
 * Emits GLSL `int getOutputCoords()` for a rank-1 output.
 *
 * Single-row and single-column textures map resultUV straight to an index.
 * Other layouts convert the texel (row, col) into a row-major flat index.
 *
 * @param shape logical output shape (unused; kept for a uniform signature).
 * @param texShape [rows, cols] of the output texture.
 * @returns GLSL source text.
 */
function getOutput1DCoords(shape, texShape) {
    var texRows = texShape[0];
    var texCols = texShape[1];
    var statement;
    if (texRows === 1) {
        statement = "return int(resultUV.x * " + texCols + ".0);";
    }
    else if (texCols === 1) {
        statement = "return int(resultUV.y * " + texRows + ".0);";
    }
    else {
        statement = "ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(" + texRows + ", " + texCols + "));\n return resTexRC.x * " + texCols + " + resTexRC.y;";
    }
    return "\n int getOutputCoords() {\n " + statement + "\n }\n ";
}
/**
 * Emits GLSL `ivec3 getOutputCoords()` for a rank-3 output.
 *
 * resultUV is scaled by the texture shape to get the texel (row, col), which
 * becomes a row-major flat index. The (r, c, d) coordinates are then peeled
 * off with the logical shape's strides.
 *
 * @param shape logical output shape (rank 3).
 * @param texShape [rows, cols] of the output texture.
 * @returns GLSL source text.
 */
function getOutput3DCoords(shape, texShape) {
    var colStride = shape[2];
    var rowStride = shape[1] * shape[2];
    var texRows = texShape[0];
    var texCols = texShape[1];
    return "\n ivec3 getOutputCoords() {" +
        "\n ivec2 resTexRC = ivec2(resultUV.yx *" +
        "\n vec2(" + texRows + ", " + texCols + "));" +
        "\n int index = resTexRC.x * " + texCols + " + resTexRC.y;" +
        "\n int r = index / " + rowStride + ";" +
        "\n index -= r * " + rowStride + ";" +
        "\n int c = index / " + colStride + ";" +
        "\n int d = index - c * " + colStride + ";" +
        "\n return ivec3(r, c, d);" +
        "\n }\n ";
}
/**
 * Emits GLSL `ivec4 getOutputCoords()` for a rank-4 output.
 *
 * The texel position becomes a row-major flat index. Each coordinate is then
 * peeled off by dividing by its stride and subtracting the consumed part; the
 * innermost coordinate is whatever remains.
 *
 * @param shape logical output shape (rank 4).
 * @param texShape [rows, cols] of the output texture.
 * @returns GLSL source text.
 */
function getOutput4DCoords(shape, texShape) {
    // Row-major strides for axes 0..2, accumulated from the innermost axis.
    var strides = [shape[3]];
    for (var axis = 2; axis >= 1; axis--) {
        strides.unshift(shape[axis] * strides[0]);
    }
    var coordNames = ['r', 'c', 'd', 'd2'];
    var glsl = "\n ivec4 getOutputCoords() {" +
        "\n ivec2 resTexRC = ivec2(resultUV.yx *" +
        "\n vec2(" + texShape[0] + ", " + texShape[1] + "));" +
        "\n int index = resTexRC.x * " + texShape[1] + " + resTexRC.y;";
    for (var k = 0; k < strides.length; k++) {
        var coord = coordNames[k];
        glsl += "\n\n int " + coord + " = index / " + strides[k] + ";";
        if (k + 1 < strides.length) {
            glsl += "\n index -= " + coord + " * " + strides[k] + ";";
        }
        else {
            glsl += "\n int " + coordNames[k + 1] + " = index - " + coord + " * " + strides[k] + ";";
        }
    }
    return glsl + "\n\n return ivec4(r, c, d, d2);" + "\n }\n ";
}
/**
 * Emits GLSL `ivec5 getOutputCoords()` for a rank-5 output. `ivec5` is the
 * struct declared in SHADER_PREFIX.
 *
 * The texel position becomes a row-major flat index. Coordinates are then
 * peeled off stride by stride; the innermost one is the remainder.
 *
 * @param shape logical output shape (rank 5).
 * @param texShape [rows, cols] of the output texture.
 * @returns GLSL source text.
 */
function getOutput5DCoords(shape, texShape) {
    // Row-major strides for axes 0..3, accumulated from the innermost axis.
    var strides = [shape[4]];
    for (var axis = 3; axis >= 1; axis--) {
        strides.unshift(shape[axis] * strides[0]);
    }
    var coordNames = ['r', 'c', 'd', 'd2', 'd3'];
    var glsl = "\n ivec5 getOutputCoords() {" +
        "\n ivec2 resTexRC = ivec2(resultUV.yx * vec2(" + texShape[0] + "," +
        "\n " + texShape[1] + "));" +
        "\n\n int index = resTexRC.x * " + texShape[1] + " + resTexRC.y;";
    for (var k = 0; k < strides.length; k++) {
        var coord = coordNames[k];
        glsl += "\n\n int " + coord + " = index / " + strides[k] + ";";
        if (k + 1 < strides.length) {
            glsl += "\n index -= " + coord + " * " + strides[k] + ";";
        }
        else {
            glsl += "\n int " + coordNames[k + 1] + " = index - " + coord + " * " + strides[k] + ";";
        }
    }
    return glsl +
        "\n\n ivec5 outShape = ivec5(r, c, d, d2, d3);" +
        "\n return outShape;" +
        "\n }\n ";
}
/**
 * Emits GLSL `ivec6 getOutputCoords()` for a rank-6 output. `ivec6` is the
 * struct declared in SHADER_PREFIX.
 *
 * The texel position becomes a row-major flat index. Coordinates are then
 * peeled off stride by stride; the innermost one is the remainder.
 *
 * @param shape logical output shape (rank 6).
 * @param texShape [rows, cols] of the output texture.
 * @returns GLSL source text.
 */
function getOutput6DCoords(shape, texShape) {
    // Row-major strides for axes 0..4, accumulated from the innermost axis.
    var strides = [shape[5]];
    for (var axis = 4; axis >= 1; axis--) {
        strides.unshift(shape[axis] * strides[0]);
    }
    var coordNames = ['r', 'c', 'd', 'd2', 'd3', 'd4'];
    var glsl = "\n ivec6 getOutputCoords() {" +
        "\n ivec2 resTexRC = ivec2(resultUV.yx *" +
        "\n vec2(" + texShape[0] + ", " + texShape[1] + "));" +
        "\n int index = resTexRC.x * " + texShape[1] + " + resTexRC.y;";
    for (var k = 0; k < strides.length; k++) {
        var coord = coordNames[k];
        glsl += "\n\n int " + coord + " = index / " + strides[k] + ";";
        if (k + 1 < strides.length) {
            glsl += "\n index -= " + coord + " * " + strides[k] + ";";
        }
        else {
            glsl += "\n int " + coordNames[k + 1] + " = index - " + coord + " * " + strides[k] + ";";
        }
    }
    return glsl +
        "\n\n ivec6 result = ivec6(r, c, d, d2, d3, d4);" +
        "\n return result;" +
        "\n }\n ";
}
/**
 * Emits GLSL `ivec2 getOutputCoords()` for a rank-2 output. Cases, in order:
 *  - logical shape equals texture shape: the texel position is the coordinate;
 *  - one column: the flat index is the row (column 0);
 *  - one row: the flat index is the column (row 0);
 *  - otherwise: split the flat index by the column count.
 *
 * @param shape logical output shape (rank 2).
 * @param texShape [rows, cols] of the output texture.
 * @returns GLSL source text.
 */
function getOutput2DCoords(shape, texShape) {
    var texRows = texShape[0];
    var texCols = texShape[1];
    var header = "\n ivec2 getOutputCoords() {\n ";
    var footer = "\n }\n ";
    if (util.arraysEqual(shape, texShape)) {
        return header + "return ivec2(resultUV.yx * vec2(" + texRows + ", " + texCols + "));" + footer;
    }
    var flatIndex = "ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(" + texRows + ", " + texCols + "));\n int index = resTexRC.x * " + texCols + " + resTexRC.y;\n ";
    var tail;
    if (shape[1] === 1) {
        tail = "return ivec2(index, 0);";
    }
    else if (shape[0] === 1) {
        tail = "return ivec2(0, index);";
    }
    else {
        tail = "int r = index / " + shape[1] + ";\n int c = index - r * " + shape[1] + ";\n return ivec2(r, c);";
    }
    return header + flatIndex + tail + footer;
}
/**
 * Emits GLSL `float get<Name>()` for a rank-0 input. A uniform input is
 * returned directly; a texture input is read at its single texel centre.
 *
 * @param inputInfo input descriptor with `name` and `shapeInfo.isUniform`.
 * @returns GLSL source text.
 */
function getSamplerScalar(inputInfo) {
    var name = inputInfo.name;
    var getter = 'get' + name.charAt(0).toUpperCase() + name.slice(1);
    return inputInfo.shapeInfo.isUniform
        ? "float " + getter + "() {return " + name + ";}"
        : "\n float " + getter + "() {\n return sampleTexture(" + name + ", halfCR);\n }\n ";
}
/**
 * Emits GLSL `float get<Name>(int index)` for a rank-1 input. It forwards to
 * the flat getter, because a 1-D index already is the flat index.
 *
 * @param inputInfo input descriptor with `name`.
 * @returns GLSL source text.
 */
function getSampler1D(inputInfo) {
    var name = inputInfo.name;
    var getter = 'get' + name.charAt(0).toUpperCase() + name.slice(1);
    var flatGetter = getter + 'Flat';
    return "\n float " + getter + "(int index) {\n return " + flatGetter + "(index);\n }\n ";
}
/**
 * Emits GLSL `float get<Name>(int row, int col)` for a rank-2 input.
 * Strategies, tried in order:
 *  1. logical shape equals texture shape: sample texel (col, row) directly;
 *  2. shape has size-1 dims: emit the sampler for the squeezed shape plus this
 *     two-argument overload, which forwards only the kept coordinates (the
 *     squeezed copy keeps the same name, so this relies on GLSL overloading);
 *  3. uniform input: flatten to an index and call get<Name>Flat, which
 *     getInputSamplingSnippet emits alongside this sampler;
 *  4. single-column / single-row texture: closed-form 1-D UV;
 *  5. otherwise: the UVfrom2D helper from SHADER_PREFIX.
 *
 * @param inputInfo input descriptor with `name` and `shapeInfo`.
 * @returns GLSL source text.
 */
function getSampler2D(inputInfo) {
var shape = inputInfo.shapeInfo.logicalShape;
var texName = inputInfo.name;
var funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
var texShape = inputInfo.shapeInfo.texShape;
// texShape is null-checked here; presumably uniform inputs may have no
// texture shape — confirm against callers.
if (texShape != null && util.arraysEqual(shape, texShape)) {
var texNumR_1 = texShape[0];
var texNumC_1 = texShape[1];
return "\n float " + funcName + "(int row, int col) {\n vec2 uv = (vec2(col, row) + halfCR) / vec2(" + texNumC_1 + ".0, " + texNumR_1 + ".0);\n return sampleTexture(" + texName + ", uv);\n }\n ";
}
var _a = util.squeezeShape(shape), newShape = _a.newShape, keptDims = _a.keptDims;
var squeezedShape = newShape;
if (squeezedShape.length < shape.length) {
var newInputInfo = squeezeInputInfo(inputInfo, squeezedShape);
var params = ['row', 'col'];
return "\n " + getSamplerFromInInfo(newInputInfo) + "\n float " + funcName + "(int row, int col) {\n return " + funcName + "(" + getSqueezedParams(params, keptDims) + ");\n }\n ";
}
if (inputInfo.shapeInfo.isUniform) {
return "\n float " + funcName + "(int row, int col) {\n int index = row * " + shape[1] + " + col;\n return " + funcName + "Flat(index);\n }\n ";
}
var texNumR = texShape[0];
var texNumC = texShape[1];
// Single-column texture: the row-major index selects the texture row.
if (texNumC === 1) {
return "\n float " + funcName + "(int row, int col) {\n int index = row * " + shape[1] + " + col;\n vec2 uv = vec2(0.5, (float(index) + 0.5) / " + texNumR + ".0);\n return sampleTexture(" + texName + ", uv);\n }\n ";
}
// Single-row texture: the row-major index selects the texture column.
if (texNumR === 1) {
return "\n float " + funcName + "(int row, int col) {\n int index = row * " + shape[1] + " + col;\n vec2 uv = vec2((float(index) + 0.5) / " + texNumC + ".0, 0.5);\n return sampleTexture(" + texName + ", uv);\n }\n ";
}
return "\n float " + funcName + "(int row, int col) {\n vec2 uv = UVfrom2D(" + texNumR + ", " + texNumC + ", " + shape[1] + ", row, col);\n return sampleTexture(" + texName + ", uv);\n }\n";
}
/**
 * Emits GLSL `float get<Name>(int row, int col, int depth)` for a rank-3
 * input. Strategies, tried in order:
 *  1. shape has size-1 dims: emit the squeezed-rank sampler plus an overload
 *     with this signature that forwards only the kept coordinates;
 *  2. uniform input: flatten with row-major strides and call get<Name>Flat;
 *  3. texture width == stride0: each texture row holds one `row` slice, so
 *     texR = row and texC is the offset inside that slice;
 *  4. texture width == stride1 (the innermost dim): each texture row holds
 *     the depth values of one (row, col) pair;
 *  5. otherwise: the generic UVfrom3D helper.
 *
 * @param inputInfo input descriptor with `name` and `shapeInfo`.
 * @returns GLSL source text.
 */
function getSampler3D(inputInfo) {
var shape = inputInfo.shapeInfo.logicalShape;
var texName = inputInfo.name;
var funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
// Row-major element strides of the logical shape.
var stride0 = shape[1] * shape[2];
var stride1 = shape[2];
var _a = util.squeezeShape(shape), newShape = _a.newShape, keptDims = _a.keptDims;
var squeezedShape = newShape;
if (squeezedShape.length < shape.length) {
var newInputInfo = squeezeInputInfo(inputInfo, squeezedShape);
var params = ['row', 'col', 'depth'];
return "\n " + getSamplerFromInInfo(newInputInfo) + "\n float " + funcName + "(int row, int col, int depth) {\n return " + funcName + "(" + getSqueezedParams(params, keptDims) + ");\n }\n ";
}
if (inputInfo.shapeInfo.isUniform) {
return "\n float " + funcName + "(int row, int col, int depth) {\n int index = row * " + stride0 + " + col * " + stride1 + " + depth;\n return " + funcName + "Flat(index);\n }\n ";
}
var texShape = inputInfo.shapeInfo.texShape;
var texNumR = texShape[0];
var texNumC = texShape[1];
if (texNumC === stride0) {
return "\n float " + funcName + "(int row, int col, int depth) {\n int texR = row;\n int texC = col * " + stride1 + " + depth;\n vec2 uv = (vec2(texC, texR) + halfCR) /\n vec2(" + texNumC + ".0, " + texNumR + ".0);\n return sampleTexture(" + texName + ", uv);\n }\n ";
}
if (texNumC === stride1) {
return "\n float " + funcName + "(int row, int col, int depth) {\n int texR = row * " + shape[1] + " + col;\n int texC = depth;\n vec2 uv = (vec2(texC, texR) + halfCR) / vec2(" + texNumC + ".0, " + texNumR + ".0);\n return sampleTexture(" + texName + ", uv);\n }\n ";
}
return "\n float " + funcName + "(int row, int col, int depth) {\n vec2 uv = UVfrom3D(\n " + texNumR + ", " + texNumC + ", " + stride0 + ", " + stride1 + ", row, col, depth);\n return sampleTexture(" + texName + ", uv);\n }\n ";
}
/**
 * Emits GLSL `float get<Name>(int row, int col, int depth, int depth2)` for
 * a rank-4 input. Strategies, tried in order:
 *  1. shape has size-1 dims: emit the squeezed-rank sampler plus an overload
 *     with this signature that forwards only the kept coordinates;
 *  2. uniform input: flatten with row-major strides and call get<Name>Flat;
 *  3. texture width == stride0: each texture row holds one `row` slice;
 *  4. texture width == stride2 (the innermost dim): each texture row holds
 *     the depth2 values of one (row, col, depth) triple, so texR is the flat
 *     index over the first three coordinates;
 *  5. otherwise: the generic UVfrom4D helper.
 *
 * @param inputInfo input descriptor with `name` and `shapeInfo`.
 * @returns GLSL source text.
 */
function getSampler4D(inputInfo) {
var shape = inputInfo.shapeInfo.logicalShape;
var texName = inputInfo.name;
var funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
// Row-major element strides of the logical shape.
var stride2 = shape[3];
var stride1 = shape[2] * stride2;
var stride0 = shape[1] * stride1;
var _a = util.squeezeShape(shape), newShape = _a.newShape, keptDims = _a.keptDims;
if (newShape.length < shape.length) {
var newInputInfo = squeezeInputInfo(inputInfo, newShape);
var params = ['row', 'col', 'depth', 'depth2'];
return "\n " + getSamplerFromInInfo(newInputInfo) + "\n float " + funcName + "(int row, int col, int depth, int depth2) {\n return " + funcName + "(" + getSqueezedParams(params, keptDims) + ");\n }\n ";
}
if (inputInfo.shapeInfo.isUniform) {
return "\n float " + funcName + "(int row, int col, int depth, int depth2) {\n int index = row * " + stride0 + " + col * " + stride1 + " +\n depth * " + stride2 + " + depth2;\n return " + funcName + "Flat(index);\n }\n ";
}
var texShape = inputInfo.shapeInfo.texShape;
var texNumR = texShape[0];
var texNumC = texShape[1];
if (texNumC === stride0) {
return "\n float " + funcName + "(int row, int col, int depth, int depth2) {\n int texR = row;\n int texC = col * " + stride1 + " + depth * " + stride2 + " + depth2;\n vec2 uv = (vec2(texC, texR) + halfCR) /\n vec2(" + texNumC + ".0, " + texNumR + ".0);\n return sampleTexture(" + texName + ", uv);\n }\n ";
}
if (texNumC === stride2) {
return "\n float " + funcName + "(int row, int col, int depth, int depth2) {\n int texR = row * " + shape[1] * shape[2] + " + col * " + shape[2] + " + depth;\n int texC = depth2;\n vec2 uv = (vec2(texC, texR) + halfCR) /\n vec2(" + texNumC + ".0, " + texNumR + ".0);\n return sampleTexture(" + texName + ", uv);\n }\n ";
}
return "\n float " + funcName + "(int row, int col, int depth, int depth2) {\n vec2 uv = UVfrom4D(" + texNumR + ", " + texNumC + ", " + stride0 + ", " + stride1 + ",\n " + stride2 + ", row, col, depth, depth2);\n return sampleTexture(" + texName + ", uv);\n }\n ";
}
/**
 * Emits GLSL `float get<Name>(int row, int col, int depth, int depth2,
 * int depth3)` for a rank-5 input. Strategies, tried in order:
 *  1. shape has size-1 dims: emit the squeezed-rank sampler plus an overload
 *     with this signature that forwards only the kept coordinates;
 *  2. uniform input: flatten with row-major strides and call get<Name>Flat;
 *  3. texture width == stride0: each texture row holds one `row` slice;
 *  4. texture width == stride3 (the innermost dim): each texture row holds
 *     the depth3 values of one (row, col, depth, depth2) tuple;
 *  5. otherwise: the generic UVfrom5D helper.
 *
 * @param inputInfo input descriptor with `name` and `shapeInfo`
 *     (`logicalShape`, `texShape`, `isUniform`).
 * @returns GLSL source text.
 */
function getSampler5D(inputInfo) {
    var shape = inputInfo.shapeInfo.logicalShape;
    var texName = inputInfo.name;
    var funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
    // Row-major element strides of the logical shape.
    var stride3 = shape[4];
    var stride2 = shape[3] * stride3;
    var stride1 = shape[2] * stride2;
    var stride0 = shape[1] * stride1;
    var _a = util.squeezeShape(shape), newShape = _a.newShape, keptDims = _a.keptDims;
    if (newShape.length < shape.length) {
        var newInputInfo = squeezeInputInfo(inputInfo, newShape);
        var params = ['row', 'col', 'depth', 'depth2', 'depth3'];
        return "\n " + getSamplerFromInInfo(newInputInfo) + "\n float " + funcName + "(int row, int col, int depth, int depth2, int depth3) {\n return " + funcName + "(" + getSqueezedParams(params, keptDims) + ");\n }\n ";
    }
    if (inputInfo.shapeInfo.isUniform) {
        return "\n float " + funcName + "(int row, int col, int depth, int depth2, int depth3) {\n int index = row * " + stride0 + " + col * " + stride1 + " +\n depth * " + stride2 + " + depth2 * " + stride3 + " + depth3;\n return " + funcName + "Flat(index);\n }\n ";
    }
    var texShape = inputInfo.shapeInfo.texShape;
    var texNumR = texShape[0];
    var texNumC = texShape[1];
    if (texNumC === stride0) {
        return "\n float " + funcName + "(int row, int col, int depth, int depth2, int depth3) {\n int texR = row;\n int texC = col * " + stride1 + " + depth * " + stride2 + " +\n depth2 * " + stride3 + " + depth3;\n vec2 uv = (vec2(texC, texR) + halfCR) /\n vec2(" + texNumC + ".0, " + texNumR + ".0);\n return sampleTexture(" + texName + ", uv);\n }\n ";
    }
    if (texNumC === stride3) {
        // texR is the flat index over (row, col, depth, depth2), i.e. the
        // element strides divided by shape[4]. The previous code used
        // row * shape[1]*shape[2] + col * shape[2], which addressed the wrong
        // texture row whenever shape[3] > 1.
        var rowTexStride = shape[1] * shape[2] * shape[3];
        var colTexStride = shape[2] * shape[3];
        var depthTexStride = shape[3];
        return "\n float " + funcName + "(int row, int col, int depth, int depth2, int depth3) {\n int texR = row * " + rowTexStride + " + col * " + colTexStride + " +\n depth * " + depthTexStride + " + depth2;\n int texC = depth3;\n vec2 uv = (vec2(texC, texR) + halfCR) /\n vec2(" + texNumC + ".0, " + texNumR + ".0);\n return sampleTexture(" + texName + ", uv);\n }\n ";
    }
    return "\n float " + funcName + "(int row, int col, int depth, int depth2, int depth3) {\n vec2 uv = UVfrom5D(" + texNumR + ", " + texNumC + ", " + stride0 + ", " + stride1 + ",\n " + stride2 + ", " + stride3 + ", row, col, depth, depth2, depth3);\n return sampleTexture(" + texName + ", uv);\n }\n ";
}
/**
 * Emits GLSL `float get<Name>(int row, int col, int depth, int depth2,
 * int depth3, int depth4)` for a rank-6 input. Strategies, tried in order:
 *  1. shape has size-1 dims: emit the squeezed-rank sampler plus an overload
 *     with this signature that forwards only the kept coordinates;
 *  2. uniform input: flatten with row-major strides and call get<Name>Flat;
 *  3. texture width == stride0: each texture row holds one `row` slice, so
 *     texC is the offset over the remaining five coordinates;
 *  4. texture width == stride4 (the innermost dim): each texture row holds
 *     the depth4 values of one (row, ..., depth3) tuple;
 *  5. otherwise: the generic UVfrom6D helper.
 *
 * Fixes relative to the previous version:
 *  - the uniform index multiplied depth3 by stride3 instead of stride4;
 *  - the uniform statement was missing its terminating ';', so the emitted
 *    GLSL did not compile;
 *  - the stride0 branch dropped the depth2/depth3/depth4 terms;
 *  - the stride4 branch computed texR with 3-D coefficients.
 *
 * @param inputInfo input descriptor with `name` and `shapeInfo`
 *     (`logicalShape`, `texShape`, `isUniform`).
 * @returns GLSL source text.
 */
function getSampler6D(inputInfo) {
    var shape = inputInfo.shapeInfo.logicalShape;
    var texName = inputInfo.name;
    var funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
    // Row-major element strides of the logical shape.
    var stride4 = shape[5];
    var stride3 = shape[4] * stride4;
    var stride2 = shape[3] * stride3;
    var stride1 = shape[2] * stride2;
    var stride0 = shape[1] * stride1;
    var _a = util.squeezeShape(shape), newShape = _a.newShape, keptDims = _a.keptDims;
    if (newShape.length < shape.length) {
        var newInputInfo = squeezeInputInfo(inputInfo, newShape);
        var params = ['row', 'col', 'depth', 'depth2', 'depth3', 'depth4'];
        return "\n " + getSamplerFromInInfo(newInputInfo) + "\n float " + funcName + "(int row, int col, int depth,\n int depth2, int depth3, int depth4) {\n return " + funcName + "(" + getSqueezedParams(params, keptDims) + ");\n }\n ";
    }
    if (inputInfo.shapeInfo.isUniform) {
        return "\n float " + funcName + "(int row, int col, int depth,\n int depth2, int depth3, int depth4) {\n int index = row * " + stride0 + " + col * " + stride1 + " +\n depth * " + stride2 + " + depth2 * " + stride3 + " + depth3 * " + stride4 + "\n + depth4;\n return " + funcName + "Flat(index);\n }\n ";
    }
    var texShape = inputInfo.shapeInfo.texShape;
    var texNumR = texShape[0];
    var texNumC = texShape[1];
    if (texNumC === stride0) {
        return "\n float " + funcName + "(int row, int col, int depth,\n int depth2, int depth3, int depth4) {\n int texR = row;\n int texC = col * " + stride1 + " + depth * " + stride2 + " +\n depth2 * " + stride3 + " + depth3 * " + stride4 + " + depth4;\n vec2 uv = (vec2(texC, texR) + halfCR) /\n vec2(" + texNumC + ".0, " + texNumR + ".0);\n return sampleTexture(" + texName + ", uv);\n }\n ";
    }
    if (texNumC === stride4) {
        // texR is the flat index over (row, col, depth, depth2, depth3), i.e.
        // the element strides divided by shape[5].
        var rowTexStride = shape[1] * shape[2] * shape[3] * shape[4];
        var colTexStride = shape[2] * shape[3] * shape[4];
        var depthTexStride = shape[3] * shape[4];
        var depth2TexStride = shape[4];
        return "\n float " + funcName + "(int row, int col, int depth,\n int depth2, int depth3, int depth4) {\n int texR = row * " + rowTexStride + " + col * " + colTexStride + " +\n depth * " + depthTexStride + " + depth2 * " + depth2TexStride + " + depth3;\n int texC = depth4;\n vec2 uv = (vec2(texC, texR) + halfCR) /\n vec2(" + texNumC + ".0, " + texNumR + ".0);\n return sampleTexture(" + texName + ", uv);\n }\n ";
    }
    return "\n float " + funcName + "(int row, int col, int depth,\n int depth2, int depth3, int depth4) {\n vec2 uv = UVfrom6D(" + texNumR + ", " + texNumC + ", " + stride0 + ", " + stride1 + ",\n " + stride2 + ", " + stride3 + ", " + stride4 + "\n ,row, col, depth, depth2, depth3, depth4);\n return sampleTexture(" + texName + ", uv);\n }\n ";
}
/**
 * Emits GLSL `float get<Name>Flat(int index)`, which reads element `index`
 * of an input in row-major (flat) order.
 *
 * - Uniform inputs: a single-element uniform is returned directly. An array
 *   uniform is read through a loop that compares against `index`, presumably
 *   because GLSL ES 1.0 fragment shaders only allow constant-index
 *   expressions for uniform arrays — confirm.
 * - Texture inputs: 1x1, single-column and single-row textures get a
 *   closed-form UV; anything else goes through UVfrom1D.
 *
 * @param inputInfo input descriptor with `name` and `shapeInfo`.
 * @returns GLSL source text.
 */
function getSamplerFlat(inputInfo) {
var texName = inputInfo.name;
var funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1) + 'Flat';
var inSize = util.sizeFromShape(inputInfo.shapeInfo.logicalShape);
if (inputInfo.shapeInfo.isUniform) {
if (inSize === 1) {
return "float " + funcName + "(int index) {return " + texName + ";}";
}
// NOTE(review): the emitted function has no return after the loop, so an
// out-of-range index falls off the end of a non-void GLSL function.
return "\n float " + funcName + "(int index) {\n for (int i = 0; i < " + inSize + "; i++) {\n if (i == index) {\n return " + texName + "[i];\n }\n }\n }\n ";
}
var texShape = inputInfo.shapeInfo.texShape;
var tNumR = texShape[0];
var tNumC = texShape[1];
// A 1x1 texture has a single texel; sample its centre regardless of index.
if (tNumC === 1 && tNumR === 1) {
return "\n float " + funcName + "(int index) {\n return sampleTexture(" + texName + ", halfCR);\n }\n ";
}
if (tNumC === 1) {
return "\n float " + funcName + "(int index) {\n vec2 uv = vec2(0.5, (float(index) + 0.5) / " + tNumR + ".0);\n return sampleTexture(" + texName + ", uv);\n }\n ";
}
if (tNumR === 1) {
return "\n float " + funcName + "(int index) {\n vec2 uv = vec2((float(index) + 0.5) / " + tNumC + ".0, 0.5);\n return sampleTexture(" + texName + ", uv);\n }\n ";
}
return "\n float " + funcName + "(int index) {\n vec2 uv = UVfrom1D(" + tNumR + ", " + tNumC + ", index);\n return sampleTexture(" + texName + ", uv);\n }\n ";
}
/**
 * Emits GLSL `float <funcName>()`, which reads the input at the current
 * output coordinates while broadcasting. It:
 *  - fetches the output coordinates;
 *  - zeroes the components that correspond to the input's broadcast
 *    dimensions;
 *  - forwards the trailing `inRank` components to `get<texFuncSnippet>`.
 *
 * @param inputInfo input descriptor (`shapeInfo.logicalShape`).
 * @param outShapeInfo output descriptor (`logicalShape`).
 * @param texFuncSnippet capitalised input name used in `get<...>`.
 * @param funcName name of the emitted GLSL function.
 * @returns GLSL source text.
 */
function getBroadcastOutputCoordsSampler(inputInfo, outShapeInfo, texFuncSnippet, funcName) {
    var inShape = inputInfo.shapeInfo.logicalShape;
    var outShape = outShapeInfo.logicalShape;
    var inRank = inShape.length;
    var outRank = outShape.length;
    // GLSL type of the getOutputCoords() result; other ranks fall back to 'int'.
    var coordsType = ['int', 'int', 'ivec2', 'ivec3', 'ivec4'][outRank] || 'int';
    var broadcastDims = broadcast_util.getBroadcastDims(inShape, outShape);
    var rankDiff = outRank - inRank;
    // Rank 0/1 outputs yield a plain int, which cannot be indexed.
    var scalarCoords = outRank < 2;
    var zeroing;
    if (inRank === 0) {
        zeroing = '';
    }
    else if (scalarCoords && broadcastDims.length >= 1) {
        zeroing = 'coords = 0;';
    }
    else {
        var zeroLines = [];
        for (var k = 0; k < broadcastDims.length; k++) {
            zeroLines.push("coords[" + (broadcastDims[k] + rankDiff) + "] = 0;");
        }
        zeroing = zeroLines.join('\n');
    }
    var callArgs;
    if (scalarCoords && inRank > 0) {
        callArgs = 'coords';
    }
    else {
        var argList = [];
        for (var i = 0; i < inRank; i++) {
            argList.push("coords[" + (i + rankDiff) + "]");
        }
        callArgs = argList.join(', ');
    }
    return "\n float " + funcName + "() {\n " + coordsType + " coords = getOutputCoords();\n " + zeroing + "\n return get" + texFuncSnippet + "(" + callArgs + ");\n }\n ";
}
/**
 * Emits GLSL `float get<Name>AtOutCoords()`, which reads this input at the
 * element that corresponds to the current output element.
 *
 * @param inputInfo input descriptor (`name`, `shapeInfo`).
 * @param outShapeInfo output descriptor (`logicalShape`, `texShape`).
 * @param supportsBroadcasting when falsy, no broadcasting logic is emitted.
 * @returns GLSL source text.
 */
function getSamplerAtOutputCoords(inputInfo, outShapeInfo, supportsBroadcasting) {
var texName = inputInfo.name;
var texFuncSnippet = texName.charAt(0).toUpperCase() + texName.slice(1);
var funcName = 'get' + texFuncSnippet + 'AtOutCoords';
var broadcastDims = broadcast_util.getBroadcastDims(inputInfo.shapeInfo.logicalShape, outShapeInfo.logicalShape);
var inRank = inputInfo.shapeInfo.logicalShape.length;
var outRank = outShapeInfo.logicalShape.length;
var doBroadcast = supportsBroadcasting && ((outRank > inRank) || broadcastDims.length > 0);
var broadcastOverOuter = broadcast_util.broadcastDimsAreOuter(broadcastDims);
var isUniform = inputInfo.shapeInfo.isUniform;
// Broadcast dims that are not all outer (per broadcastDimsAreOuter) need
// real N-D coordinates, so this case is delegated to the coordinate-based
// sampler.
if (doBroadcast && !broadcastOverOuter) {
return getBroadcastOutputCoordsSampler(inputInfo, outShapeInfo, texFuncSnippet, funcName);
}
var inSize = util.sizeFromShape(inputInfo.shapeInfo.logicalShape);
var broadcastSnippet = '';
// Outer-only broadcasting: the emitted GLSL reduces the flat output index
// modulo inSize before reading the input.
if (doBroadcast && broadcastOverOuter) {
broadcastSnippet = "\n int mainPart = index / " + inSize + ";\n index -= mainPart * " + inSize + ";\n ";
}
var outTexShape = outShapeInfo.texShape;
if (isUniform) {
if (inSize === 1) {
return "float " + funcName + "() {return " + texName + ";}";
}
return "\n float " + funcName + "() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(" + outTexShape[0] + ", " + outTexShape[1] + "));\n int index = resTexRC.x * " + outTexShape[1] + " + resTexRC.y;\n " + broadcastSnippet + "\n return get" + texFuncSnippet + "Flat(index);\n }\n ";
}
var inTexShape = inputInfo.shapeInfo.texShape;
// Same texture layout as the output: sample the input at resultUV directly.
if (util.arraysEqual(inTexShape, outTexShape)) {
return "\n float " + funcName + "() {\n return sampleTexture(" + texName + ", resultUV);\n }\n ";
}
// General case: rebuild the flat output index, then map it onto the input
// texture's (row, col) layout.
return "\n float " + funcName + "() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(" + outTexShape[0] + ", " + outTexShape[1] + "));\n int index = resTexRC.x * " + outTexShape[1] + " + resTexRC.y;\n " + broadcastSnippet + "\n int texR = index / " + inTexShape[1] + ";\n int texC = index - texR * " + inTexShape[1] + ";\n vec2 uv = (vec2(texC, texR) + halfCR) /\n vec2(" + inTexShape[1] + ".0, " + inTexShape[0] + ".0);\n\n return sampleTexture(" + texName + ", uv);\n }\n ";
}
function getCoordsDataType(rank) {
if (rank <= 1) {
return 'int';
}
else if (rank === 2) {
return 'ivec2';
}
else if (rank === 3) {
return 'ivec3';
}
else if (rank === 4) {
return 'ivec4';
}
else if (rank === 5) {
return 'ivec5';
}
else if (rank === 6) {
return 'ivec6';
}
else {
throw Error("GPU for rank " + rank + " is not yet supported");
}
}
exports.getCoordsDataType = getCoordsDataType;
/**
 * Returns a deep copy of an input descriptor whose logical shape is replaced
 * by `squeezedShape`. The caller's descriptor is not modified.
 *
 * @param inInfo input descriptor (must be JSON-serialisable).
 * @param squeezedShape shape to install; stored by reference, not copied.
 * @returns the new descriptor.
 */
function squeezeInputInfo(inInfo, squeezedShape) {
    // JSON round-trip produces an independent deep copy of the descriptor.
    var serialized = JSON.stringify(inInfo);
    var clone = JSON.parse(serialized);
    var clonedShapeInfo = clone.shapeInfo;
    clonedShapeInfo.logicalShape = squeezedShape;
    return clone;
}
/**
 * Selects the GLSL parameter names for the dimensions that survived
 * squeezing and joins them into an argument list.
 *
 * @param params parameter names, one per original dimension.
 * @param keptDims indices (into `params`) of the kept dimensions.
 * @returns comma-separated argument list, e.g. "row, depth".
 */
function getSqueezedParams(params, keptDims) {
    var selected = [];
    for (var i = 0; i < keptDims.length; i++) {
        selected.push(params[keptDims[i]]);
    }
    return selected.join(', ');
}
//# sourceMappingURL=shader_compiler.js.map