@tensorflow/tfjs-core
Version:
Hardware-accelerated JavaScript library for machine intelligence
132 lines (112 loc) • 4 kB
text/typescript
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import * as util from '../../util';
/**
 * Produces GLSL code that derives logical coordinates from a flat index.
 *
 * The strides come from `util.computeStrides(shape)`. For each stride, the
 * emitted code does two things:
 * 1. It integer-divides the index variable by that stride to declare one
 *    coordinate.
 * 2. It subtracts that coordinate's contribution from the index variable.
 *
 * For the last stride, the final (innermost) coordinate is declared as the
 * remainder instead of mutating the index variable again.
 *
 * Example for `coords = ['r', 'c', 'd']`, `shape = [2, 3, 4]`,
 * `index = 'flat'`:
 *   int r = flat / 12; flat -= r * 12;int c = flat / 4; int d = flat - c * 4;
 *
 * @param coords Names of the GLSL `int` variables to declare, one per
 *     dimension of `shape`.
 * @param shape Logical shape whose strides are used.
 * @param index Name of the GLSL variable holding the flat index. The emitted
 *     code assigns to it (`-=`) for every stride except the last, so it must
 *     be a mutable lvalue in the surrounding shader.
 * @returns GLSL statements. The result is empty when `util.computeStrides`
 *     yields no strides; in that case no coordinate variables are declared.
 */
export function getLogicalCoordinatesFromFlatIndex(
    coords: string[], shape: number[], index = 'index'): string {
  const strides = util.computeStrides(shape);
  return strides
      .map((stride, i) => {
        const line1 = `int ${coords[i]} = ${index} / ${stride}`;
        // The innermost coordinate is whatever remains after the last
        // division. Earlier strides peel their contribution off the
        // caller-named index variable, which must be `${index}` and not a
        // hard-coded `index`.
        const line2 = i === strides.length - 1 ?
            `int ${coords[i + 1]} = ${index} - ${coords[i]} * ${stride}` :
            `${index} -= ${coords[i]} * ${stride}`;
        return `${line1}; ${line2};`;
      })
      .join('');
}
/**
 * Formats a list of GLSL operands as a single expression.
 *
 * A lone operand is returned unchanged. Otherwise the operands are
 * comma-joined inside a `vecN(...)` constructor, where N is the number of
 * operands.
 */
function buildVec(x: string[]): string {
  const arity = x.length;
  return arity === 1 ? `${x[0]}` : `vec${arity}(${x.join(',')})`;
}
/**
 * Produces a GLSL expression computing the dot product of the `x` and `y`
 * operand lists.
 *
 * Operands are grouped into chunks of up to four. Each chunk pair is emitted
 * as a `dot(vecN(...), vecN(...))` call, and the calls are joined with `+`.
 * A trailing chunk of exactly one operand is wrapped in `float(...)` so that
 * `dot` receives scalar float arguments.
 *
 * @param x GLSL expressions for the left-hand vector.
 * @param y GLSL expressions for the right-hand vector; must have the same
 *     length as `x`.
 * @returns The GLSL expression, or `'0.0'` when both lists are empty. That is
 *     the empty dot product, and it keeps callers from splicing an empty
 *     string into GLSL.
 * @throws Error if `x` and `y` have different lengths.
 */
export function dotify(x: string[], y: string[]): string {
  if (x.length !== y.length) {
    throw new Error(
        `Vectors to be dotted must be of the same length - ` +
        `got ${x.length} and ${y.length}`);
  }
  if (x.length === 0) {
    return '0.0';
  }
  const terms: string[] = [];
  for (let start = 0; start < x.length; start += 4) {
    let xChunk = x.slice(start, start + 4);
    let yChunk = y.slice(start, start + 4);
    // buildVec returns a single operand bare (no vecN wrapper). Cast it
    // explicitly so dot() is called with floats. Only a trailing chunk can
    // have length 1.
    if (xChunk.length === 1) {
      xChunk = xChunk.map(d => `float(${d})`);
      yChunk = yChunk.map(d => `float(${d})`);
    }
    terms.push(`dot(${buildVec(xChunk)}, ${buildVec(yChunk)})`);
  }
  return terms.join('+');
}
/**
 * Produces GLSL source for `int getFlatIndex(ivec3 coords)`.
 *
 * The generated function combines 3D coordinates into a flat index using the
 * first two strides from `util.computeStrides(shape)`. The `z` coordinate is
 * added with unit stride.
 */
export function getFlatIndexFrom3D(shape: [number, number, number]): string {
  const [xStride, yStride] = util.computeStrides(shape);
  return `
int getFlatIndex(ivec3 coords) {
return coords.x * ${xStride} + coords.y * ${yStride} + coords.z;
}
`;
}
/**
 * GLSL source defining `encode_float(highp float v)`.
 *
 * The function packs a float into a `lowp vec4`. After scaling by 255, the
 * four channels hold the bytes of an IEEE-754 single-precision value, with
 * the least-significant byte in channel 0:
 *   - c[0]: low 8 mantissa bits.
 *   - c[1]: middle 8 mantissa bits.
 *   - c[2]: top 7 mantissa bits, plus the lowest exponent bit (value 128).
 *   - c[3]: top 7 bits of the biased exponent, plus the sign bit (value 128).
 * The fractional mantissa is `exp2(fract(log2(|v|))) - 1`. The exponent is
 * `floor(log2(|v|))`, biased by 127.
 *
 * Special cases are handled before the general path:
 *   - NaN returns vec4(255, 255, 255, 255).
 *     NOTE(review): unlike the other branches, this value is not divided by
 *     255.0. Presumably it relies on output clamping to [0, 1]. Confirm this
 *     is intended, since lowp only guarantees a small range.
 *   - |v| < FLOAT_MIN returns all zeros. FLOAT_MIN is 2^-126, the smallest
 *     normal float, so denormals and -0.0 are flushed to +0.0.
 *   - v > FLOAT_MAX or v < -FLOAT_MAX returns the +Inf / -Inf byte patterns.
 *     NOTE(review): FLOAT_MAX is 2^127, not the largest finite float
 *     (~3.4e38). Finite values above 2^127 are therefore also encoded as
 *     infinity; confirm this is acceptable.
 * NOTE(review): the snippet calls `isnan`, which must be available to the
 * shader (either built in or defined elsewhere). Confirm for the target GLSL
 * version.
 */
export const ENCODE_FLOAT_SNIPPET = `
const float FLOAT_MAX = 1.70141184e38;
const float FLOAT_MIN = 1.17549435e-38;
lowp vec4 encode_float(highp float v) {
if (isnan(v)) {
return vec4(255, 255, 255, 255);
}
highp float av = abs(v);
if(av < FLOAT_MIN) {
return vec4(0.0, 0.0, 0.0, 0.0);
} else if(v > FLOAT_MAX) {
return vec4(0.0, 0.0, 128.0, 127.0) / 255.0;
} else if(v < -FLOAT_MAX) {
return vec4(0.0, 0.0, 128.0, 255.0) / 255.0;
}
highp vec4 c = vec4(0,0,0,0);
highp float e = floor(log2(av));
highp float m = exp2(fract(log2(av))) - 1.0;
c[2] = floor(128.0 * m);
m -= c[2] / 128.0;
c[1] = floor(32768.0 * m);
m -= c[1] / 32768.0;
c[0] = floor(8388608.0 * m);
highp float ebias = e + 127.0;
c[3] = floor(ebias / 2.0);
ebias -= c[3] * 2.0;
c[2] += floor(ebias) * 128.0;
c[3] += 128.0 * step(0.0, -v);
return c / 255.0;
}
`;