UNPKG

keras-js

Version:

Run Keras models in the browser, with GPU support using WebGL

43 lines (37 loc) 1.36 kB
"use strict"; Object.defineProperty(exports, "__esModule", { value: true }); exports.default = conv2dTranspose; function conv2dTranspose(outputShape, inputShape, indexMapShape, useBias, hasFragments) { const addBias = useBias ? `sum += texelFetch(bias, ivec2(out_x, 0), 0).r;` : ''; const adjustIndicesForFragments = hasFragments ? `int fragmentIndex = int(floor(float(rowIndex) / float(${inputShape[0]}))); rowIndex = int(mod(float(rowIndex), float(${inputShape[0]}))); colIndex += fragmentIndex * ${inputShape[1]};` : ''; const source = `#version 300 es precision highp float; precision highp isampler2D; in vec2 outTex; uniform sampler2D matMulResult; uniform isampler2D indexMap; uniform sampler2D bias; out vec4 outColor; void main() { int out_y = int(float(${outputShape[0]}) * outTex.y); int out_x = int(float(${outputShape[1]}) * outTex.x); float sum = 0.; for (int n = 0; n < ${indexMapShape[1]}; ++n) { int index = texelFetch(indexMap, ivec2(n, out_y), 0).r; if (index != -1) { int rowIndex = int(floor(float(index) / float(${inputShape[1]}))); int colIndex = int(mod(float(index), float(${inputShape[1]}))); ${adjustIndicesForFragments} sum += texelFetch(matMulResult, ivec2(colIndex + out_x, rowIndex), 0).r; } } ${addBias} outColor = vec4(sum); } `; return source; }