@tensorflow/tfjs-core
Version:
Hardware-accelerated JavaScript library for machine intelligence
222 lines (170 loc) • 5.57 kB
text/typescript
/**
* @license
* Copyright 2017 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import * as erf_util from '../../ops/erf_util';
import * as selu_util from '../../ops/selu_util';
import {GPGPUProgram} from './gpgpu_math';
/**
 * GPGPU program that applies a single-input GLSL snippet element-wise.
 *
 * The snippet becomes the body of `unaryOperation(float x)`; `main` reads
 * one value of input `A` at the output coordinates, applies the operation,
 * and writes the result.
 */
export class UnaryOpProgram implements GPGPUProgram {
  variableNames = ['A'];
  outputShape: number[];
  userCode: string;

  constructor(aShape: number[], opSnippet: string) {
    // Element-wise op: output shape mirrors the input shape.
    this.outputShape = aShape;
    const fragmentSource = `
float unaryOperation(float x) {
${opSnippet}
}
void main() {
float x = getAAtOutCoords();
float y = unaryOperation(x);
setOutput(y);
}
`;
    this.userCode = fragmentSource;
  }
}
// Guard prepended to snippets that must pass NaN inputs through unchanged.
const CHECK_NAN_SNIPPET = 'if (isnan(x)) return x;';

// Identity.
export const LINEAR = 'return x;';

// Absolute value.
export const ABS = 'return abs(x);';

// Rectified linear unit; NaN inputs propagate via the guard.
export const RELU = CHECK_NAN_SNIPPET + '\nreturn (x < 0.0) ? 0.0 : x;\n';

// Exponential linear unit.
export const ELU = 'return (x >= 0.0) ? x : (exp(x) - 1.0);';
// Scaled exponential linear unit. `scaleAlpha` and `scale` are interpolated
// from selu_util at module-load time, so the GLSL sees literal numbers.
// (The // lines below are GLSL comments inside the shader string.)
export const SELU = `
// Stable and Attracting Fixed Point (0, 1) for Normalized Weights.
// see: https://arxiv.org/abs/1706.02515
float scaleAlpha = ${selu_util.SELU_SCALEALPHA};
float scale = ${selu_util.SELU_SCALE};
return (x >= 0.0) ? scale * x : scaleAlpha * (exp(x) - 1.0);
`;
/**
 * Step snippet: NaN passes through, strictly positive inputs map to 1.0,
 * everything else to `alpha` (interpolated into the shader as a GLSL
 * float literal).
 */
export function STEP(alpha = 0.0) {
  const stepBody = `
return x > 0.0 ? 1.0 : float(${alpha});
`;
  return CHECK_NAN_SNIPPET + stepBody;
}
// Arithmetic negation.
export const NEG = 'return -x;';

// Rounding toward +infinity / -infinity.
export const CEIL = 'return ceil(x);';
export const FLOOR = 'return floor(x);';

// sign() with NaN explicitly mapped to 0.0.
export const SIGN = `
if (isnan(x)) { return 0.0; }
return sign(x);
`;

// Float-valued predicates (1.0 = true, 0.0 = false).
export const IS_NAN = 'return float(isnan(x));';
export const IS_INF = 'return float(isinf(x));';
export const IS_FINITE = 'return float(!isnan(x) && !isinf(x));';

// Banker's rounding, since the targeted GLSL version lacks a round()
// built-in (see the GLSL comment inside the snippet).
export const ROUND = `
// OpenGL ES does not support round function.
// The algorithm is based on banker's rounding.
float base = floor(x);
if ((x - base) < 0.5) {
return floor(x);
} else if ((x - base) > 0.5) {
return ceil(x);
} else {
if (mod(base, 2.0) == 0.0) {
return base;
} else {
return base + 1.0;
}
}
`;

// Exponentials and logarithms; LOG rejects negative inputs with NAN.
export const EXP = 'return exp(x);';
export const EXPM1 = 'return exp(x) - 1.0;';
export const LOG = 'if (x < 0.0) return NAN;\nreturn log(x);';
export const LOG1P = 'return log(1.0 + x);';

// Square root and reciprocal square root.
export const SQRT = 'return sqrt(x);';
export const RSQRT = 'return inversesqrt(x);';

// Logistic sigmoid.
export const SIGMOID = 'return 1.0 / (1.0 + exp(-1.0 * x));';
/**
 * Numerically-stable softplus, mirroring tf.nn.softplus:
 * https://goo.gl/vkcvwX
 *
 * `epsilon` is the difference between 1.0 and the next representable
 * float32 value, i.e. 2^-23; see:
 * https://math.byu.edu/~schow/work/IEEEFloatingPoint.htm
 *
 * When x > -threshold ("too_large"), exp(x) may overflow, but
 * softplus(x) == x to within machine epsilon, so x is returned directly.
 *
 * When x < threshold ("too_small"), exp(x) may underflow, but
 * softplus(x) == exp(x) to within machine epsilon, so exp(x) is returned.
 *
 * Otherwise the direct formula log(exp(x) + 1) is safe to evaluate.
 */
export const SOFTPLUS = `
float epsilon = 1.1920928955078125e-7;
float threshold = log(epsilon) + 2.0;
bool too_large = x > -threshold;
bool too_small = x < threshold;
float result;
float exp_x = exp(x);
if (too_large){
result = x;
}
else if (too_small){
result = exp_x;
}
else{
result = log(exp_x + 1.0);
}
return result;
`;
// Trigonometric ops; sin, cos and atan propagate NaN via the guard.
export const SIN = CHECK_NAN_SNIPPET + '\nreturn sin(x);\n';
export const COS = CHECK_NAN_SNIPPET + '\nreturn cos(x);\n';
export const TAN = 'return tan(x);';
export const ASIN = 'return asin(x);';
export const ACOS = 'return acos(x);';
export const ATAN = CHECK_NAN_SNIPPET + '\nreturn atan(x);\n';

// Hyperbolic ops expressed via exp(); `e2x` is a shader-local temporary.
export const SINH =
    '\nfloat e2x = exp(x);\nreturn (e2x - 1.0 / e2x) / 2.0;\n';
export const COSH =
    '\nfloat e2x = exp(-x);\nreturn (e2x + 1.0 / e2x) / 2.0;\n';
// Uses exp(-2|x|), which never exceeds 1, so exp() cannot overflow.
export const TANH =
    '\nfloat e2x = exp(-2.0 * abs(x));\n' +
    'return sign(x) * (1.0 - e2x) / (1.0 + e2x);\n';

// Inverse hyperbolic ops; inputs outside the real domain return NAN.
export const ASINH = 'return log(x + sqrt(x * x + 1.0));';
export const ACOSH = CHECK_NAN_SNIPPET +
    '\nif (x < 1.0) return NAN;\n' +
    'return log(x + sqrt(x * x - 1.0));';
export const ATANH = CHECK_NAN_SNIPPET +
    '\nif ((x < -1.0) || (x > 1.0)) return NAN;\n' +
    'return (log(1.0 + x) - log(1.0 - x)) / 2.0;';
/**
 * Polynomial approximation of the error function (Abramowitz & Stegun,
 * "Handbook of Mathematical Functions", formula 7.1.26). The `p` and
 * `a1`..`a5` coefficients are interpolated from erf_util at module-load
 * time, so the shader sees literal float values.
 *
 * Fix: formula 7.1.26 is only valid for x >= 0 — for negative x,
 * t = 1/(1 + p*x) leaves the approximation's domain and the result is
 * wrong. Since erf is odd (erf(-x) == -erf(x)), evaluate on |x| and
 * reapply the sign.
 */
export const ERF = `
// Error function is calculated approximately with elementary function.
// See "Handbook of Mathematical Functions with Formulas,
// Graphs, and Mathematical Tables", Abramowitz and Stegun.
float p = ${erf_util.ERF_P};
float a1 = ${erf_util.ERF_A1};
float a2 = ${erf_util.ERF_A2};
float a3 = ${erf_util.ERF_A3};
float a4 = ${erf_util.ERF_A4};
float a5 = ${erf_util.ERF_A5};
float sign = sign(x);
x = abs(x);
float t = 1.0 / (1.0 + p * x);
return sign * (1.0 - (((((a5*t + a4)*t) + a3)*t + a2)*t + a1)*t*exp(-x*x));
`;
/** x squared. */
export const SQUARE = 'return x * x;';
/** Multiplicative inverse. */
export const RECIPROCAL = 'return 1.0 / x;';
/** Logical NOT: 1.0 whenever x is not >= 1.0, else 0.0. */
export const LOGICAL_NOT = 'return float(!(x >= 1.0));';
/** Truncation toward zero via a GLSL int cast, returned as float. */
export const TO_INT = 'return float(int(x));';
/** Identity snippet used for cloning. */
export const CLONE = 'return x;';