UNPKG

@tensorflow/tfjs-core

Version:

Hardware-accelerated JavaScript library for machine intelligence

132 lines (112 loc) 4 kB
/**
 * @license
 * Copyright 2018 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */

import * as util from '../../util';

/**
 * Produces GLSL code that derives logical coordinates from a flat index.
 *
 * For every stride returned by `util.computeStrides(shape)`, the generated
 * code integer-divides the index by that stride to obtain one coordinate and
 * then subtracts that coordinate's contribution from the index. The final
 * coordinate is the remainder left after the last stride.
 *
 * NOTE(review): if `util.computeStrides` yields no strides (presumably for
 * shapes of rank < 2 — confirm), this returns an empty string and declares
 * none of `coords`.
 *
 * @param coords GLSL variable names to declare as `int`, one per dimension.
 *     Must contain at least `strides.length + 1` entries, otherwise
 *     `undefined` is interpolated into the generated code.
 * @param shape The logical shape the flat index refers to.
 * @param index Name of the GLSL `int` variable holding the flat index. The
 *     generated code both reads and decrements this variable, so it must be
 *     a mutable l-value in the surrounding shader.
 * @returns Semicolon-terminated GLSL statements, concatenated without
 *     separators.
 */
export function getLogicalCoordinatesFromFlatIndex(
    coords: string[], shape: number[], index = 'index'): string {
  const strides = util.computeStrides(shape);
  const lastStrideIdx = strides.length - 1;
  return strides
      .map((stride, i) => {
        const quotient = `int ${coords[i]} = ${index} / ${stride}`;
        // Decrement the caller-supplied index variable (not a hard-coded
        // `index`) so the same variable is read and updated.
        const remainder = i === lastStrideIdx ?
            `int ${coords[i + 1]} = ${index} - ${coords[i]} * ${stride}` :
            `${index} -= ${coords[i]} * ${stride}`;
        return `${quotient}; ${remainder};`;
      })
      .join('');
}

/**
 * Builds a GLSL expression from scalar component expressions.
 *
 * A single component is returned unchanged. Otherwise the components are
 * comma-joined inside a `vecN(...)` constructor, where N is the number of
 * components.
 */
function buildVec(x: string[]): string {
  if (x.length === 1) {
    return `${x[0]}`;
  }
  return `vec${x.length}(${x.join(',')})`;
}

/**
 * Produces GLSL code that computes the dot product of the input x and y
 * vectors. Handles splitting inputs into increments of vec4s when necessary.
 *
 * Components are consumed four at a time, each chunk becoming a
 * `dot(vec4(...), vec4(...))` term. Any trailing 1-3 components form one final
 * `dot` term on a `float`, `vec2` or `vec3`. All terms are joined with `+`.
 *
 * @param x GLSL scalar expressions for the first vector's components.
 * @param y GLSL scalar expressions for the second vector's components.
 * @returns A GLSL expression; an empty string when both inputs are empty.
 * @throws Error if `x` and `y` have different lengths.
 */
export function dotify(x: string[], y: string[]): string {
  if (x.length !== y.length) {
    throw new Error(
        `Vectors to be dotted must be of the same length - ` +
        `got ${x.length} and ${y.length}`);
  }

  const terms: string[] = [];
  const fullChunkCount = Math.floor(x.length / 4);
  const leftoverCount = x.length % 4;

  for (let i = 0; i < fullChunkCount; i++) {
    const start = i * 4;
    const xChunk = x.slice(start, start + 4);
    const yChunk = y.slice(start, start + 4);
    terms.push(`${buildVec(xChunk)}, ${buildVec(yChunk)}`);
  }

  if (leftoverCount !== 0) {
    let xTail = x.slice(fullChunkCount * 4);
    let yTail = y.slice(fullChunkCount * 4);
    // vecN(...) constructors convert their arguments to float, but buildVec
    // returns a lone component unwrapped. Cast it explicitly so `dot`
    // receives float operands.
    if (xTail.length === 1) {
      xTail = xTail.map(d => `float(${d})`);
      yTail = yTail.map(d => `float(${d})`);
    }
    terms.push(`${buildVec(xTail)}, ${buildVec(yTail)}`);
  }

  return terms.map(d => `dot(${d})`).join('+');
}

/**
 * Produces GLSL that computes the flat index from 3D coordinates.
 *
 * The emitted `int getFlatIndex(ivec3 coords)` multiplies `coords.x` and
 * `coords.y` by the first two strides of `shape` (from
 * `util.computeStrides`) and adds `coords.z` unscaled.
 *
 * @param shape The 3D logical shape the coordinates refer to.
 * @returns GLSL source defining `getFlatIndex`.
 */
export function getFlatIndexFrom3D(shape: [number, number, number]): string {
  const strides = util.computeStrides(shape).map(d => d.toString());
  return ` int getFlatIndex(ivec3 coords) { return coords.x * ${strides[0]} + coords.y * ${strides[1]} + coords.z; } `;
}

/**
 * GLSL snippet defining `FLOAT_MAX`, `FLOAT_MIN` and
 * `lowp vec4 encode_float(highp float v)`.
 *
 * `encode_float` splits `v` into four byte values (each divided by 255):
 * - `c[0]`, `c[1]` and the low 7 bits of `c[2]` hold the 23 mantissa bits.
 * - The exponent (biased by 127) is split between the top bit of `c[2]` and
 *   the low 7 bits of `c[3]`.
 * - The top bit of `c[3]` is the sign.
 *
 * This matches the little-endian byte layout of an IEEE-754 binary32 value.
 *
 * Special cases:
 * - |v| < FLOAT_MIN (2^-126) encodes as all zeros.
 * - v > FLOAT_MAX (2^127) encodes as bytes 00 00 80 7F.
 * - v < -FLOAT_MAX encodes as bytes 00 00 80 FF.
 *
 * NOTE(review): the NaN branch returns vec4(255, ...) without dividing by
 * 255. This presumably relies on the value being clamped to 1.0 when written
 * to an 8-bit target — confirm.
 *
 * NOTE(review): `isnan` is not a GLSL ES 1.00 built-in. Presumably it is
 * defined elsewhere in the shader prelude — confirm.
 */
export const ENCODE_FLOAT_SNIPPET = ` const float FLOAT_MAX = 1.70141184e38; const float FLOAT_MIN = 1.17549435e-38; lowp vec4 encode_float(highp float v) { if (isnan(v)) { return vec4(255, 255, 255, 255); } highp float av = abs(v); if(av < FLOAT_MIN) { return vec4(0.0, 0.0, 0.0, 0.0); } else if(v > FLOAT_MAX) { return vec4(0.0, 0.0, 128.0, 127.0) / 255.0; } else if(v < -FLOAT_MAX) { return vec4(0.0, 0.0, 128.0, 255.0) / 255.0; } highp vec4 c = vec4(0,0,0,0); highp float e = floor(log2(av)); highp float m = exp2(fract(log2(av))) - 1.0; c[2] = floor(128.0 * m); m -= c[2] / 128.0; c[1] = floor(32768.0 * m); m -= c[1] / 32768.0; c[0] = floor(8388608.0 * m); highp float ebias = e + 127.0; c[3] = floor(ebias / 2.0); ebias -= c[3] * 2.0; c[2] += floor(ebias) * 128.0; c[3] += 128.0 * step(0.0, -v); return c / 255.0; } `;