UNPKG

@tensorflow/tfjs-core

Version:

Hardware-accelerated JavaScript library for machine intelligence

302 lines 21.8 kB
import { ENV } from '../../environment';
import * as util from '../../util';
import * as broadcast_util from '../../ops/broadcast_util';
import * as tex_util from './tex_util';

/**
 * Assembles the source of a GLSL fragment shader.
 *
 * The result is the newline-joined concatenation of, in order: the shared
 * SHADER_PREFIX, the texture sampling snippet, the setOutput snippet, one
 * `uniform sampler2D` declaration per input, the getOutputCoords() helper for
 * the output shape, the sampler functions for every input, and `userCode`.
 *
 * @param inputsInfo Array of input descriptors. Each one is read for `name`
 *     and `shapeInfo` (`logicalShape` and `texShape`).
 * @param outputShape Output descriptor; read for `logicalShape` and
 *     `texShape`.
 * @param userCode GLSL source appended after all generated helpers.
 * @param broadcast When truthy, "AtOutCoords" samplers are generated for every
 *     input and may broadcast; otherwise they are only generated for inputs
 *     whose logical shape equals the output's logical shape.
 * @returns The complete shader source string.
 */
export function makeShader(inputsInfo, outputShape, userCode, broadcast) {
    var sampleSnippet = getSampleSnippet();
    var setOutputSnippet = getSetOutputSnippet();
    var inputPrefixSnippet = inputsInfo
        .map(function (x) { return "uniform sampler2D " + x.name + ";"; })
        .join('\n');
    var inputSamplingSnippet = inputsInfo
        .map(function (x) {
            return getInputSamplingSnippet(x, outputShape, broadcast);
        })
        .join('\n');
    var outTexShape = outputShape.texShape;
    var outputSamplingSnippet = getOutputSamplingSnippet(outputShape.logicalShape, outTexShape);
    var source = [
        SHADER_PREFIX,
        sampleSnippet,
        setOutputSnippet,
        inputPrefixSnippet,
        outputSamplingSnippet,
        inputSamplingSnippet,
        userCode
    ].join('\n');
    return source;
}

/**
 * Picks the sampleTexture() implementation based on the
 * 'WEBGL_FLOAT_TEXTURE_ENABLED' environment flag.
 */
function getSampleSnippet() {
    return ENV.get('WEBGL_FLOAT_TEXTURE_ENABLED') ?
        FLOAT_TEXTURE_SAMPLE_SNIPPET :
        UNSIGNED_BYTE_TEXTURE_SAMPLE_SNIPPET;
}

/**
 * Picks the setOutput() implementation based on the
 * 'WEBGL_FLOAT_TEXTURE_ENABLED' environment flag.
 */
function getSetOutputSnippet() {
    return ENV.get('WEBGL_FLOAT_TEXTURE_ENABLED') ?
        FLOAT_TEXTURE_SETOUTPUT_SNIPPET :
        UNSIGNED_BYTE_TEXTURE_SETOUTPUT_SNIPPET;
}

/**
 * Dispatches on the rank of the input's logical shape to the matching
 * coordinate-based sampler generator.
 *
 * @throws Error when the rank is greater than 4.
 */
function getSamplerFromInInfo(inInfo) {
    var shape = inInfo.shapeInfo.logicalShape;
    switch (shape.length) {
        case 0:
            return getSamplerScalar(inInfo);
        case 1:
            return getSampler1D(inInfo);
        case 2:
            return getSampler2D(inInfo);
        case 3:
            return getSampler3D(inInfo);
        case 4:
            return getSampler4D(inInfo);
        default:
            throw new Error(shape.length + "-D input sampling" +
                " is not yet supported");
    }
}

/**
 * Generates every sampler function for one input: the flat-index sampler, the
 * coordinate sampler, and (when broadcasting is requested or the logical
 * shapes match) the sampler evaluated at the current output coordinates.
 */
function getInputSamplingSnippet(inInfo, outShapeInfo, broadcast) {
    var res = getSamplerFlat(inInfo);
    res += getSamplerFromInInfo(inInfo);
    if (broadcast ||
        util.arraysEqual(inInfo.shapeInfo.logicalShape, outShapeInfo.logicalShape)) {
        res += getSamplerAtOutputCoords(inInfo, outShapeInfo, broadcast);
    }
    return res;
}

/**
 * Dispatches on the output rank to the generator of getOutputCoords().
 *
 * @throws Error when the rank is greater than 4.
 */
function getOutputSamplingSnippet(outShape, outTexShape) {
    switch (outShape.length) {
        case 0:
            return getOutputScalarCoords();
        case 1:
            return getOutput1DCoords(outShape, outTexShape);
        case 2:
            return getOutput2DCoords(outShape, outTexShape);
        case 3:
            return getOutput3DCoords(outShape, outTexShape);
        case 4:
            return getOutput4DCoords(outShape, outTexShape);
        default:
            throw new Error(outShape.length + "-D output sampling is not yet supported");
    }
}

// GLSL helper: maps a flat index to the UV of the texel that holds it.
var SAMPLE_1D_SNIPPET = "\n" +
    "vec2 UVfrom1D(int texNumR, int texNumC, int index) {\n" +
    " int texR = index / texNumC;\n" +
    " int texC = index - texR * texNumC;\n" +
    " return (vec2(texC, texR) + halfCR) / vec2(texNumC, texNumR);\n" +
    "}\n";

// GLSL helper: maps (row, col) of a matrix with numC columns to a texel UV.
var SAMPLE_2D_SNIPPET = "\n" +
    "vec2 UVfrom2D(int texNumR, int texNumC, int numC, int row, int col) {\n" +
    " int index = row * numC + col;\n" +
    " int texR = index / texNumC;\n" +
    " int texC = index - texR * texNumC;\n" +
    " return (vec2(texC, texR) + halfCR) / vec2(texNumC, texNumR);\n" +
    "}\n";

// GLSL helper: maps (row, col, depth) with the given strides to a texel UV.
var SAMPLE_3D_SNIPPET = "\n" +
    "vec2 UVfrom3D(int texNumR, int texNumC, int stride0,\n" +
    " int stride1, int row, int col, int depth) {\n" +
    " // Explicitly use integer operations as dot() only works on floats.\n" +
    " int index = row * stride0 + col * stride1 + depth;\n" +
    " int texR = index / texNumC;\n" +
    " int texC = index - texR * texNumC;\n" +
    " return (vec2(texC, texR) + halfCR) / vec2(texNumC, texNumR);\n" +
    "}\n";

// GLSL helper: maps (row, col, depth, depth2) with the given strides to a
// texel UV.
var SAMPLE_4D_SNIPPET = "\n" +
    "vec2 UVfrom4D(int texNumR, int texNumC, int stride0,\n" +
    " int stride1, int stride2, int row, int col, int depth,\n" +
    " int depth2) {\n" +
    " // Explicitly use integer operations as dot() only works on floats.\n" +
    " int index = row * stride0 + col * stride1 + depth * stride2 + depth2;\n" +
    " int texR = index / texNumC;\n" +
    " int texC = index - texR * texNumC;\n" +
    " return (vec2(texC, texR) + halfCR) / vec2(texNumC, texNumR);\n" +
    "}\n";

// GLSL sampleTexture() for unsigned-byte textures: decodes a float from the
// four channels of a texel, returning the NaN uniform when all channels equal
// tex_util.BYTE_NAN_VALUE. Also declares minValue/maxValue/range, which the
// unsigned-byte setOutput snippet reads.
var UNSIGNED_BYTE_TEXTURE_SAMPLE_SNIPPET = "\n" +
    " uniform float NaN;\n" +
    "\n" +
    " const vec4 floatDeltas = vec4(\n" +
    " 1.0,\n" +
    " 1.0 / 255.0,\n" +
    " 1.0 / (255.0 * 255.0),\n" +
    " 1.0 / (255.0 * 255.0 * 255.0)\n" +
    " );\n" +
    " const float minValue = " + tex_util.FLOAT_MIN + ".0;\n" +
    " const float maxValue = " + tex_util.FLOAT_MAX + ".0;\n" +
    " const float range = (maxValue - minValue) / 255.0;\n" +
    " const vec2 dotRange = vec2(1.0, range);\n" +
    "\n" +
    " float sampleTexture(sampler2D textureSampler, vec2 uv) {\n" +
    " vec4 sampleValue = texture2D(textureSampler, uv);\n" +
    " if (all(equal(sampleValue, vec4(" + tex_util.BYTE_NAN_VALUE + ")))) {\n" +
    " return NaN;\n" +
    " }\n" +
    "\n" +
    " vec4 encValue = floor(sampleValue * 255.0 + 0.5);\n" +
    " float decodedValue = dot(encValue, floatDeltas);\n" +
    " return dot(vec2(minValue, decodedValue), dotRange);\n" +
    " }\n";

// GLSL setOutput() for unsigned-byte textures: encodes a float into the four
// channels of gl_FragColor, writing tex_util.BYTE_NAN_VALUE to every channel
// for NaN inputs.
var UNSIGNED_BYTE_TEXTURE_SETOUTPUT_SNIPPET = "\n" +
    " const vec4 floatPowers = vec4(\n" +
    " 1.0,\n" +
    " 255.0,\n" +
    " 255.0 * 255.0,\n" +
    " 255.0 * 255.0 * 255.0\n" +
    " );\n" +
    " const vec2 recipRange = vec2(1.0/range);\n" +
    " const vec2 recipRange255 = vec2(1.0/(maxValue - minValue));\n" +
    "\n" +
    " void setOutput(float decodedValue) {\n" +
    " if (isNaN(decodedValue)) {\n" +
    " gl_FragColor = vec4(" + tex_util.BYTE_NAN_VALUE + ");\n" +
    " return;\n" +
    " }\n" +
    "\n" +
    " float a = dot(vec2(decodedValue, -minValue), recipRange);\n" +
    " float b = fract(a) * 255.0;\n" +
    " float c = fract(b) * 255.0;\n" +
    " float d = fract(c) * 255.0;\n" +
    " gl_FragColor = floor(vec4(a, b, c, d)) / 255.0;\n" +
    "\n" +
    " // TODO(dsmilkov): Version above gets better accuracy but probably slower\n" +
    " // than the version below. Benchmark to determine if the accuracy is worth\n" +
    " // the cost.\n" +
    "\n" +
    " // float normValue = dot(vec2(decodedValue, -minValue), recipRange255);\n" +
    " // vec4 f = normValue * floatPowers;\n" +
    " // gl_FragColor = floor(fract(f) * 255.0) / 255.0;\n" +
    " }\n";

// GLSL sampleTexture() for float textures: the value is the red channel.
var FLOAT_TEXTURE_SAMPLE_SNIPPET = "\n" +
    " float sampleTexture(sampler2D textureSampler, vec2 uv) {\n" +
    " return texture2D(textureSampler, uv).r;\n" +
    " }\n";

// GLSL setOutput() for float textures: the value is written to the red
// channel.
var FLOAT_TEXTURE_SETOUTPUT_SNIPPET = "\n" +
    " void setOutput(float val) {\n" +
    " gl_FragColor = vec4(val, 0, 0, 0);\n" +
    " }\n";

// GLSL emitted at the top of every shader: precision qualifiers, the resultUV
// varying, NaN/rounding/modulo/random helpers, and the UVfromND helpers. Must
// be defined after the SAMPLE_*_SNIPPET constants it embeds.
var SHADER_PREFIX = "\n" +
    " precision highp float;\n" +
    " precision highp int;\n" +
    " varying vec2 resultUV;\n" +
    " const vec2 halfCR = vec2(0.5, 0.5);\n" +
    "\n" +
    " bool isNaN(float val) {\n" +
    " float v1 = val * val;\n" +
    " float v2 = val * val;\n" +
    " return v1 == v2 ? false : true;\n" +
    " }\n" +
    "\n" +
    " bool hasNaN(vec4 values) {\n" +
    " vec4 v1 = values * values;\n" +
    " vec4 v2 = values * values;\n" +
    " return any(notEqual(v1, v2));\n" +
    " }\n" +
    "\n" +
    " float getNaN(vec4 values) {\n" +
    " return dot(vec4(1), values);\n" +
    " }\n" +
    "\n" +
    " int round(float value) {\n" +
    " return int(floor(value + 0.5));\n" +
    " }\n" +
    "\n" +
    " int imod(int x, int y) {\n" +
    " return x - y * (x / y);\n" +
    " }\n" +
    "\n" +
    " //Based on the work of Dave Hoskins\n" +
    " //https://www.shadertoy.com/view/4djSRW\n" +
    " #define HASHSCALE1 443.8975\n" +
    " float random(float seed){\n" +
    " vec2 p = resultUV * seed;\n" +
    " vec3 p3 = fract(vec3(p.xyx) * HASHSCALE1);\n" +
    " p3 += dot(p3, p3.yzx + 19.19);\n" +
    " return fract((p3.x + p3.y) * p3.z);\n" +
    " }\n" +
    "\n" +
    " " + SAMPLE_1D_SNIPPET +
    "\n " + SAMPLE_2D_SNIPPET +
    "\n " + SAMPLE_3D_SNIPPET +
    "\n " + SAMPLE_4D_SNIPPET +
    "\n";

/** Generates getOutputCoords() for a scalar output: always 0. */
function getOutputScalarCoords() {
    return "\n" +
        " int getOutputCoords() {\n" +
        " return 0;\n" +
        " }\n" +
        " ";
}

/**
 * Generates getOutputCoords() for a 1-D output. Single-row and single-column
 * textures read the index straight from one component of resultUV; otherwise
 * the texel row/column is flattened into an index.
 *
 * @param shape Logical output shape (not read by this function).
 * @param texShape [rows, cols] of the output texture.
 */
function getOutput1DCoords(shape, texShape) {
    if (texShape[0] === 1) {
        return "\n" +
            " int getOutputCoords() {\n" +
            " return int(resultUV.x * " + texShape[1] + ".0);\n" +
            " }\n" +
            " ";
    }
    if (texShape[1] === 1) {
        return "\n" +
            " int getOutputCoords() {\n" +
            " return int(resultUV.y * " + texShape[0] + ".0);\n" +
            " }\n" +
            " ";
    }
    return "\n" +
        " int getOutputCoords() {\n" +
        " ivec2 resTexRC = ivec2(resultUV.yx *\n" +
        " vec2(" + texShape[0] + ", " + texShape[1] + "));\n" +
        " return resTexRC.x * " + texShape[1] + " + resTexRC.y;\n" +
        " }\n" +
        " ";
}

/**
 * Generates getOutputCoords() for a 3-D output by flattening the texel
 * position into an index and peeling off row/col/depth with the shape strides.
 */
function getOutput3DCoords(shape, texShape) {
    var stride0 = shape[1] * shape[2];
    var stride1 = shape[2];
    return "\n" +
        " ivec3 getOutputCoords() {\n" +
        " ivec2 resTexRC = ivec2(resultUV.yx *\n" +
        " vec2(" + texShape[0] + ", " + texShape[1] + "));\n" +
        " int index = resTexRC.x * " + texShape[1] + " + resTexRC.y;\n" +
        " int r = index / " + stride0 + ";\n" +
        " index -= r * " + stride0 + ";\n" +
        " int c = index / " + stride1 + ";\n" +
        " int d = index - c * " + stride1 + ";\n" +
        " return ivec3(r, c, d);\n" +
        " }\n" +
        " ";
}

/**
 * Generates getOutputCoords() for a 4-D output by flattening the texel
 * position into an index and peeling off the four coordinates with the shape
 * strides.
 */
function getOutput4DCoords(shape, texShape) {
    var stride2 = shape[3];
    var stride1 = shape[2] * stride2;
    var stride0 = shape[1] * stride1;
    return "\n" +
        " ivec4 getOutputCoords() {\n" +
        " ivec2 resTexRC = ivec2(resultUV.yx *\n" +
        " vec2(" + texShape[0] + ", " + texShape[1] + "));\n" +
        " int index = resTexRC.x * " + texShape[1] + " + resTexRC.y;\n" +
        "\n" +
        " int r = index / " + stride0 + ";\n" +
        " index -= r * " + stride0 + ";\n" +
        "\n" +
        " int c = index / " + stride1 + ";\n" +
        " index -= c * " + stride1 + ";\n" +
        "\n" +
        " int d = index / " + stride2 + ";\n" +
        " int d2 = index - d * " + stride2 + ";\n" +
        "\n" +
        " return ivec4(r, c, d, d2);\n" +
        " }\n" +
        " ";
}

/**
 * Generates getOutputCoords() for a 2-D output. Special cases: the logical
 * shape equals the texture shape (texel position is the coordinate), a single
 * column, and a single row; the general case divides the flat index by the
 * number of logical columns.
 */
function getOutput2DCoords(shape, texShape) {
    if (util.arraysEqual(shape, texShape)) {
        return "\n" +
            " ivec2 getOutputCoords() {\n" +
            " return ivec2(resultUV.yx * vec2(" + texShape[0] + ", " + texShape[1] + "));\n" +
            " }\n" +
            " ";
    }
    if (shape[1] === 1) {
        return "\n" +
            " ivec2 getOutputCoords() {\n" +
            " ivec2 resTexRC = ivec2(resultUV.yx *\n" +
            " vec2(" + texShape[0] + ", " + texShape[1] + "));\n" +
            " int index = resTexRC.x * " + texShape[1] + " + resTexRC.y;\n" +
            " return ivec2(index, 0);\n" +
            " }\n" +
            " ";
    }
    if (shape[0] === 1) {
        return "\n" +
            " ivec2 getOutputCoords() {\n" +
            " ivec2 resTexRC = ivec2(resultUV.yx *\n" +
            " vec2(" + texShape[0] + ", " + texShape[1] + "));\n" +
            " int index = resTexRC.x * " + texShape[1] + " + resTexRC.y;\n" +
            " return ivec2(0, index);\n" +
            " }\n" +
            " ";
    }
    return "\n" +
        " ivec2 getOutputCoords() {\n" +
        " ivec2 resTexRC = ivec2(resultUV.yx *\n" +
        " vec2(" + texShape[0] + ", " + texShape[1] + "));\n" +
        " int index = resTexRC.x * " + texShape[1] + " + resTexRC.y;\n" +
        " int r = index / " + shape[1] + ";\n" +
        " int c = index - r * " + shape[1] + ";\n" +
        " return ivec2(r, c);\n" +
        " }\n" +
        " ";
}

/**
 * Generates `float get<Name>()` for a scalar input; it samples the texel at
 * halfCR.
 */
function getSamplerScalar(inputInfo) {
    var texName = inputInfo.name;
    var funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
    return "\n" +
        " float " + funcName + "() {\n" +
        " return sampleTexture(" + texName + ", halfCR);\n" +
        " }\n" +
        " ";
}

/**
 * Generates `float get<Name>(int index)` for a 1-D input; it delegates to the
 * `get<Name>Flat` sampler produced by getSamplerFlat.
 */
function getSampler1D(inputInfo) {
    var texName = inputInfo.name;
    var funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
    return "\n" +
        " float " + funcName + "(int index) {\n" +
        " return " + funcName + "Flat(index);\n" +
        " }\n" +
        " ";
}

/**
 * Generates `float get<Name>(int row, int col)` for a 2-D input.
 *
 * Branches, in order: logical shape equals texture shape; the shape can be
 * squeezed (a lower-rank sampler is emitted and this one forwards the kept
 * coordinates to it); single-column texture; single-row texture; general case
 * via UVfrom2D.
 */
function getSampler2D(inputInfo) {
    var shape = inputInfo.shapeInfo.logicalShape;
    var texShape = inputInfo.shapeInfo.texShape;
    var texName = inputInfo.name;
    var funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
    var texNumR = texShape[0];
    var texNumC = texShape[1];
    if (util.arraysEqual(shape, texShape)) {
        return "\n" +
            " float " + funcName + "(int row, int col) {\n" +
            " vec2 uv = (vec2(col, row) + halfCR) / vec2(" + texNumC + ".0, " + texNumR + ".0);\n" +
            " return sampleTexture(" + texName + ", uv);\n" +
            " }\n" +
            " ";
    }
    var _a = util.squeezeShape(shape);
    var newShape = _a.newShape;
    var keptDims = _a.keptDims;
    var squeezedShape = newShape;
    if (squeezedShape.length < shape.length) {
        var newInputInfo = squeezeInputInfo(inputInfo, squeezedShape);
        var params = ['row', 'col'];
        return "\n" +
            " " + getSamplerFromInInfo(newInputInfo) + "\n" +
            " float " + funcName + "(int row, int col) {\n" +
            " return " + funcName + "(" + getSqueezedParams(params, keptDims) + ");\n" +
            " }\n" +
            " ";
    }
    if (texNumC === 1) {
        return "\n" +
            " float " + funcName + "(int row, int col) {\n" +
            " int index = row * " + shape[1] + " + col;\n" +
            " vec2 uv = vec2(0.5, (float(index) + 0.5) / " + texNumR + ".0);\n" +
            " return sampleTexture(" + texName + ", uv);\n" +
            " }\n" +
            " ";
    }
    if (texNumR === 1) {
        return "\n" +
            " float " + funcName + "(int row, int col) {\n" +
            " int index = row * " + shape[1] + " + col;\n" +
            " vec2 uv = vec2((float(index) + 0.5) / " + texNumC + ".0, 0.5);\n" +
            " return sampleTexture(" + texName + ", uv);\n" +
            " }\n" +
            " ";
    }
    return "\n" +
        " float " + funcName + "(int row, int col) {\n" +
        " vec2 uv = UVfrom2D(" + texNumR + ", " + texNumC + ", " + shape[1] + ", row, col);\n" +
        " return sampleTexture(" + texName + ", uv);\n" +
        " }\n";
}

/**
 * Generates `float get<Name>(int row, int col, int depth)` for a 3-D input.
 *
 * Branches, in order: squeezable shape (forwards to a lower-rank sampler);
 * texture columns equal shape[1] * shape[2]; texture columns equal shape[2];
 * general case via UVfrom3D.
 */
function getSampler3D(inputInfo) {
    var texShape = inputInfo.shapeInfo.texShape;
    var shape = inputInfo.shapeInfo.logicalShape;
    var texName = inputInfo.name;
    var funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
    var texNumR = texShape[0];
    var texNumC = texShape[1];
    var stride0 = shape[1] * shape[2];
    var stride1 = shape[2];
    var _a = util.squeezeShape(shape);
    var newShape = _a.newShape;
    var keptDims = _a.keptDims;
    var squeezedShape = newShape;
    if (squeezedShape.length < shape.length) {
        var newInputInfo = squeezeInputInfo(inputInfo, squeezedShape);
        var params = ['row', 'col', 'depth'];
        return "\n" +
            " " + getSamplerFromInInfo(newInputInfo) + "\n" +
            " float " + funcName + "(int row, int col, int depth) {\n" +
            " return " + funcName + "(" + getSqueezedParams(params, keptDims) + ");\n" +
            " }\n" +
            " ";
    }
    if (texNumC === stride0) {
        return "\n" +
            " float " + funcName + "(int row, int col, int depth) {\n" +
            " int texR = row;\n" +
            " int texC = col * " + stride1 + " + depth;\n" +
            " vec2 uv = (vec2(texC, texR) + halfCR) /\n" +
            " vec2(" + texNumC + ".0, " + texNumR + ".0);\n" +
            " return sampleTexture(" + texName + ", uv);\n" +
            " }\n" +
            " ";
    }
    if (texNumC === stride1) {
        return "\n" +
            " float " + funcName + "(int row, int col, int depth) {\n" +
            " int texR = row * " + shape[1] + " + col;\n" +
            " int texC = depth;\n" +
            " vec2 uv = (vec2(texC, texR) + halfCR) / vec2(" + texNumC + ".0, " + texNumR + ".0);\n" +
            " return sampleTexture(" + texName + ", uv);\n" +
            " }\n" +
            " ";
    }
    return "\n" +
        " float " + funcName + "(int row, int col, int depth) {\n" +
        " vec2 uv = UVfrom3D(\n" +
        " " + texNumR + ", " + texNumC + ", " + stride0 + ", " + stride1 + ", row, col, depth);\n" +
        " return sampleTexture(" + texName + ", uv);\n" +
        " }\n" +
        " ";
}

/**
 * Generates `float get<Name>(int row, int col, int depth, int depth2)` for a
 * 4-D input.
 *
 * Branches, in order: squeezable shape (forwards to a lower-rank sampler);
 * texture columns equal shape[1] * shape[2] * shape[3]; texture columns equal
 * shape[3]; general case via UVfrom4D.
 */
function getSampler4D(inputInfo) {
    var shape = inputInfo.shapeInfo.logicalShape;
    var texShape = inputInfo.shapeInfo.texShape;
    var texName = inputInfo.name;
    var funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
    var texNumR = texShape[0];
    var texNumC = texShape[1];
    var stride2 = shape[3];
    var stride1 = shape[2] * stride2;
    var stride0 = shape[1] * stride1;
    var _a = util.squeezeShape(shape);
    var newShape = _a.newShape;
    var keptDims = _a.keptDims;
    if (newShape.length < shape.length) {
        var newInputInfo = squeezeInputInfo(inputInfo, newShape);
        var params = ['row', 'col', 'depth', 'depth2'];
        return "\n" +
            " " + getSamplerFromInInfo(newInputInfo) + "\n" +
            " float " + funcName + "(int row, int col, int depth, int depth2) {\n" +
            " return " + funcName + "(" + getSqueezedParams(params, keptDims) + ");\n" +
            " }\n" +
            " ";
    }
    if (texNumC === stride0) {
        return "\n" +
            " float " + funcName + "(int row, int col, int depth, int depth2) {\n" +
            " int texR = row;\n" +
            " int texC = col * " + stride1 + " + depth * " + stride2 + " + depth2;\n" +
            " vec2 uv = (vec2(texC, texR) + halfCR) /\n" +
            " vec2(" + texNumC + ".0, " + texNumR + ".0);\n" +
            " return sampleTexture(" + texName + ", uv);\n" +
            " }\n" +
            " ";
    }
    if (texNumC === stride2) {
        return "\n" +
            " float " + funcName + "(int row, int col, int depth, int depth2) {\n" +
            " int texR = row * " + (shape[1] * shape[2]) + " + col * " + shape[2] + " + depth;\n" +
            " int texC = depth2;\n" +
            " vec2 uv = (vec2(texC, texR) + halfCR) /\n" +
            " vec2(" + texNumC + ".0, " + texNumR + ".0);\n" +
            " return sampleTexture(" + texName + ", uv);\n" +
            " }\n" +
            " ";
    }
    return "\n" +
        " float " + funcName + "(int row, int col, int depth, int depth2) {\n" +
        " vec2 uv = UVfrom4D(" + texNumR + ", " + texNumC + ", " + stride0 + ", " + stride1 + ",\n" +
        " " + stride2 + ", row, col, depth, depth2);\n" +
        " return sampleTexture(" + texName + ", uv);\n" +
        " }\n" +
        " ";
}

/**
 * Generates `float get<Name>Flat(int index)`, which samples the input by flat
 * index. Special cases: 1x1 texture, single-column texture, single-row
 * texture; the general case goes through UVfrom1D.
 */
function getSamplerFlat(inputInfo) {
    var texName = inputInfo.name;
    var texShape = inputInfo.shapeInfo.texShape;
    var funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1) + 'Flat';
    var tNumR = texShape[0];
    var tNumC = texShape[1];
    if (tNumC === 1 && tNumR === 1) {
        return "\n" +
            " float " + funcName + "(int index) {\n" +
            " return sampleTexture(" + texName + ", halfCR);\n" +
            " }\n" +
            " ";
    }
    if (tNumC === 1) {
        return "\n" +
            " float " + funcName + "(int index) {\n" +
            " vec2 uv = vec2(0.5, (float(index) + 0.5) / " + tNumR + ".0);\n" +
            " return sampleTexture(" + texName + ", uv);\n" +
            " }\n" +
            " ";
    }
    if (tNumR === 1) {
        return "\n" +
            " float " + funcName + "(int index) {\n" +
            " vec2 uv = vec2((float(index) + 0.5) / " + tNumC + ".0, 0.5);\n" +
            " return sampleTexture(" + texName + ", uv);\n" +
            " }\n" +
            " ";
    }
    return "\n" +
        " float " + funcName + "(int index) {\n" +
        " vec2 uv = UVfrom1D(" + tNumR + ", " + tNumC + ", index);\n" +
        " return sampleTexture(" + texName + ", uv);\n" +
        " }\n" +
        " ";
}

/**
 * Generates the "AtOutCoords" sampler for an input that is broadcast over
 * dimensions that are not all outer dimensions. The generated function reads
 * the output coordinates, zeroes the broadcast dimensions, and forwards the
 * coordinates belonging to the input to its `get<Name>` sampler.
 *
 * @param inputInfo Input descriptor (`shapeInfo.logicalShape` is read).
 * @param outShapeInfo Output descriptor (`logicalShape` is read).
 * @param texFuncSnippet Capitalized texture name, appended to "get".
 * @param funcName Name of the generated GLSL function.
 */
function getBroadcastOutputCoordsSampler(inputInfo, outShapeInfo, texFuncSnippet, funcName) {
    var inRank = inputInfo.shapeInfo.logicalShape.length;
    var outRank = outShapeInfo.logicalShape.length;
    var type = 'int';
    if (outRank === 2) {
        type = 'ivec2';
    }
    else if (outRank === 3) {
        type = 'ivec3';
    }
    else if (outRank === 4) {
        type = 'ivec4';
    }
    var broadcastDims = broadcast_util.getBroadcastDims(inputInfo.shapeInfo.logicalShape, outShapeInfo.logicalShape);
    // Input dimensions are aligned to the trailing output dimensions.
    var rankDiff = outRank - inRank;
    var coordsSnippet;
    if (inRank === 0) {
        coordsSnippet = '';
    }
    else if (outRank < 2 && broadcastDims.length >= 1) {
        // `coords` is a plain int here, so it cannot be indexed.
        coordsSnippet = 'coords = 0;';
    }
    else {
        coordsSnippet = broadcastDims
            .map(function (d) { return "coords[" + (d + rankDiff) + "] = 0;"; })
            .join('\n');
    }
    var unpackedCoordsSnippet = '';
    if (outRank < 2 && inRank > 0) {
        unpackedCoordsSnippet = 'coords';
    }
    else {
        unpackedCoordsSnippet = inputInfo.shapeInfo.logicalShape
            .map(function (s, i) { return "coords[" + (i + rankDiff) + "]"; })
            .join(', ');
    }
    return "\n" +
        " float " + funcName + "() {\n" +
        " " + type + " coords = getOutputCoords();\n" +
        " " + coordsSnippet + "\n" +
        " return get" + texFuncSnippet + "(" + unpackedCoordsSnippet + ");\n" +
        " }\n" +
        " ";
}

/**
 * Generates `float get<Name>AtOutCoords()`, which samples the input at the
 * position of the fragment currently being written.
 *
 * Broadcasting over non-outer dimensions is delegated to
 * getBroadcastOutputCoordsSampler. When input and output textures have the
 * same shape, resultUV is used directly. Otherwise the output texel position
 * is flattened to an index (reduced modulo the input size when broadcasting
 * over outer dimensions) and mapped into the input texture.
 */
function getSamplerAtOutputCoords(inputInfo, outShapeInfo, supportsBroadcasting) {
    var inTexShape = inputInfo.shapeInfo.texShape;
    var texName = inputInfo.name;
    var texFuncSnippet = texName.charAt(0).toUpperCase() + texName.slice(1);
    var funcName = 'get' + texFuncSnippet + 'AtOutCoords';
    var broadcastDims = broadcast_util.getBroadcastDims(inputInfo.shapeInfo.logicalShape, outShapeInfo.logicalShape);
    var inRank = inputInfo.shapeInfo.logicalShape.length;
    var outRank = outShapeInfo.logicalShape.length;
    var doBroadcast = supportsBroadcasting &&
        ((outRank > inRank) || broadcastDims.length > 0);
    var broadcastOverOuter = broadcast_util.broadcastDimsAreOuter(broadcastDims);
    if (doBroadcast && !broadcastOverOuter) {
        return getBroadcastOutputCoordsSampler(inputInfo, outShapeInfo, texFuncSnippet, funcName);
    }
    var outTexShape = outShapeInfo.texShape;
    if (util.arraysEqual(inTexShape, outTexShape)) {
        return "\n" +
            " float " + funcName + "() {\n" +
            " return sampleTexture(" + texName + ", resultUV);\n" +
            " }\n" +
            " ";
    }
    var inSize = util.sizeFromShape(inTexShape);
    var broadcastSnippet = '';
    if (doBroadcast && broadcastOverOuter) {
        // index = index mod inSize, written with integer division.
        broadcastSnippet = "\n" +
            " int mainPart = index / " + inSize + ";\n" +
            " index -= mainPart * " + inSize + ";\n" +
            " ";
    }
    return "\n" +
        " float " + funcName + "() {\n" +
        " ivec2 resTexRC = ivec2(resultUV.yx *\n" +
        " vec2(" + outTexShape[0] + ", " + outTexShape[1] + "));\n" +
        " int index = resTexRC.x * " + outTexShape[1] + " + resTexRC.y;\n" +
        " " + broadcastSnippet + "\n" +
        " int texR = index / " + inTexShape[1] + ";\n" +
        " int texC = index - texR * " + inTexShape[1] + ";\n" +
        " vec2 uv = (vec2(texC, texR) + halfCR) /\n" +
        " vec2(" + inTexShape[1] + ".0, " + inTexShape[0] + ".0);\n" +
        "\n" +
        " return sampleTexture(" + texName + ", uv);\n" +
        " }\n" +
        " ";
}

/**
 * Returns the GLSL type name used to hold coordinates of the given rank:
 * 'int' for rank <= 1, 'ivec2'/'ivec3'/'ivec4' for ranks 2 to 4.
 *
 * @throws Error for any other rank.
 */
export function getCoordsDataType(rank) {
    if (rank <= 1) {
        return 'int';
    }
    else if (rank === 2) {
        return 'ivec2';
    }
    else if (rank === 3) {
        return 'ivec3';
    }
    else if (rank === 4) {
        return 'ivec4';
    }
    else {
        throw new Error("GPU for rank " + rank + " is not yet supported");
    }
}

/**
 * Returns a copy of `inInfo` (made with a JSON round-trip, so the caller's
 * object is not mutated) whose logical shape is replaced by `squeezedShape`.
 */
function squeezeInputInfo(inInfo, squeezedShape) {
    var newInputInfo = JSON.parse(JSON.stringify(inInfo));
    newInputInfo.shapeInfo.logicalShape = squeezedShape;
    return newInputInfo;
}

/**
 * Returns the comma-separated names from `params` at the indices listed in
 * `keptDims`.
 */
function getSqueezedParams(params, keptDims) {
    return keptDims.map(function (d) { return params[d]; }).join(', ');
}
//# sourceMappingURL=shader_compiler.js.map