UNPKG

@tensorflow/tfjs-core

Version:

Hardware-accelerated JavaScript library for machine intelligence

222 lines (170 loc) 5.57 kB
/**
 * @license
 * Copyright 2017 Google Inc. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */

import * as erf_util from '../../ops/erf_util';
import * as selu_util from '../../ops/selu_util';
import {GPGPUProgram} from './gpgpu_math';

/**
 * GPGPU program that applies an element-wise unary operation to input `A`.
 *
 * The generated shader reads `A` at the current output coordinate
 * (`getAAtOutCoords()`), passes it through `unaryOperation`, and writes the
 * result with `setOutput`.
 */
export class UnaryOpProgram implements GPGPUProgram {
  variableNames = ['A'];
  userCode: string;
  outputShape: number[];

  /**
   * @param aShape Shape of the input; the same array is used as the output
   *     shape.
   * @param opSnippet GLSL statements forming the body of
   *     `float unaryOperation(float x)`; must `return` a float on every path.
   */
  constructor(aShape: number[], opSnippet: string) {
    this.outputShape = aShape;
    this.userCode = `
      float unaryOperation(float x) {
        ${opSnippet}
      }

      void main() {
        float x = getAAtOutCoords();
        float y = unaryOperation(x);

        setOutput(y);
      }
    `;
  }
}

/**
 * Returns the input unchanged when it is NaN, so the statements that follow
 * this prefix only ever see non-NaN values.
 */
const CHECK_NAN_SNIPPET = `if (isnan(x)) return x;`;

export const LINEAR = `return x;`;

export const ABS = `return abs(x);`;

export const RELU = CHECK_NAN_SNIPPET + `
  return (x < 0.0) ? 0.0 : x;
`;

export const ELU = `return (x >= 0.0) ? x : (exp(x) - 1.0);`;

export const SELU = `
  // Stable and Attracting Fixed Point (0, 1) for Normalized Weights.
  // see: https://arxiv.org/abs/1706.02515
  float scaleAlpha = ${selu_util.SELU_SCALEALPHA};
  float scale = ${selu_util.SELU_SCALE};
  return (x >= 0.0) ? scale * x : scaleAlpha * (exp(x) - 1.0);
`;

/**
 * Step function: 1.0 for x > 0, `alpha` otherwise; NaN inputs pass through.
 *
 * @param alpha Value produced for non-positive inputs. It is spliced into the
 *     GLSL source as `float(<alpha>)`, so it must stringify to a valid GLSL
 *     numeric literal (NaN / Infinity would not compile).
 * @returns GLSL body for `unaryOperation`.
 */
export function STEP(alpha = 0.0): string {
  return CHECK_NAN_SNIPPET + `
    return x > 0.0 ? 1.0 : float(${alpha});
  `;
}

export const NEG = `return -x;`;

export const CEIL = `return ceil(x);`;

export const FLOOR = `return floor(x);`;

/** sign(x), except that NaN maps to 0.0. */
export const SIGN = `
  if (isnan(x)) { return 0.0; }
  return sign(x);
`;

export const IS_NAN = `return float(isnan(x));`;

export const IS_INF = `return float(isinf(x));`;

export const IS_FINITE = `return float(!isnan(x) && !isinf(x));`;

/**
 * Round half to even ("banker's rounding"): exact .5 ties go to the even
 * neighbour, all other values to the nearest integer.
 */
export const ROUND = `
  // OpenGL ES does not support round function.
  // The algorithm is based on banker's rounding.
  float base = floor(x);
  if ((x - base) < 0.5) {
    return floor(x);
  } else if ((x - base) > 0.5) {
    return ceil(x);
  } else {
    if (mod(base, 2.0) == 0.0) {
      return base;
    } else {
      return base + 1.0;
    }
  }
`;

export const EXP = `return exp(x);`;

export const EXPM1 = `return exp(x) - 1.0;`;

/**
 * Natural log; negative inputs yield NAN.
 * NOTE(review): `NAN` is not defined here — presumably provided by the shader
 * preamble; confirm.
 */
export const LOG = `if (x < 0.0) return NAN;
  return log(x);`;

export const LOG1P = `return log(1.0 + x);`;

export const SQRT = `return sqrt(x);`;

export const RSQRT = `return inversesqrt(x);`;

export const SIGMOID = `return 1.0 / (1.0 + exp(-1.0 * x));`;

/**
 * mirrors the implementation of tf.nn.softplus: https://goo.gl/vkcvwX
 *
 * epsilon is the difference between 1.0 and the next representable
 * float. For a single precision 32 bit float this should be 2^-23, see:
 * https://math.byu.edu/~schow/work/IEEEFloatingPoint.htm
 *
 * too_large = (x > -threshold) is the value above which exp(x) may overflow
 * but softplus(x) == x is within machine epsilon
 *
 * too_small = (x < threshold) is the value below which exp(x) may underflow,
 * but softplus(x) == exp(x) is within machine epsilon.
 */
export const SOFTPLUS = `
  float epsilon = 1.1920928955078125e-7;
  float threshold = log(epsilon) + 2.0;

  bool too_large = x > -threshold;
  bool too_small = x < threshold;

  float result;
  float exp_x = exp(x);

  if (too_large){
    result = x;
  }
  else if (too_small){
    result = exp_x;
  }
  else{
    result = log(exp_x + 1.0);
  }
  return result;
`;

export const SIN = CHECK_NAN_SNIPPET + `
  return sin(x);
`;

export const COS = CHECK_NAN_SNIPPET + `
  return cos(x);
`;

export const TAN = `return tan(x);`;

export const ASIN = `return asin(x);`;

export const ACOS = `return acos(x);`;

export const ATAN = CHECK_NAN_SNIPPET + `
  return atan(x);
`;

/** sinh(x) = (e^x - e^-x) / 2. */
export const SINH = `
  float e2x = exp(x);
  return (e2x - 1.0 / e2x) / 2.0;
`;

/** cosh(x) = (e^-x + e^x) / 2. */
export const COSH = `
  float e2x = exp(-x);
  return (e2x + 1.0 / e2x) / 2.0;
`;

/**
 * tanh via exp(-2|x|): the exponent is never positive, so exp cannot
 * overflow; the sign is restored with sign(x).
 */
export const TANH = `
  float e2x = exp(-2.0 * abs(x));
  return sign(x) * (1.0 - e2x) / (1.0 + e2x);
`;

/**
 * asinh(x) = sign(x) * log(|x| + sqrt(x^2 + 1)).
 *
 * Evaluating on |x| avoids the cancellation in `x + sqrt(x*x + 1.0)` for
 * large negative x, where in float32 the sum can round to 0 and produce
 * -Inf instead of a finite result.
 */
export const ASINH = `
  float ax = abs(x);
  return sign(x) * log(ax + sqrt(ax * ax + 1.0));
`;

/** acosh; NaN passes through and inputs below 1.0 yield NAN. */
export const ACOSH = CHECK_NAN_SNIPPET + `
  if (x < 1.0) return NAN;
  return log(x + sqrt(x * x - 1.0));`;

/** atanh; NaN passes through and inputs outside [-1, 1] yield NAN. */
export const ATANH = CHECK_NAN_SNIPPET + `
  if ((x < -1.0) || (x > 1.0)) return NAN;
  return (log(1.0 + x) - log(1.0 - x)) / 2.0;`;

/**
 * Error function via the Abramowitz & Stegun rational approximation
 * (formula 7.1.26), which is only valid for x >= 0. Because erf is odd
 * (erf(-x) = -erf(x)), it is evaluated on |x| and the sign is restored;
 * feeding a negative x straight into `t = 1/(1 + p*x)` would push t outside
 * (0, 1] and produce garbage.
 */
export const ERF = `
  // Error function is calculated approximately with elementary function.
  // See "Handbook of Mathematical Functions with Formulas,
  // Graphs, and Mathematical Tables", Abramowitz and Stegun.
  float p = ${erf_util.ERF_P};
  float a1 = ${erf_util.ERF_A1};
  float a2 = ${erf_util.ERF_A2};
  float a3 = ${erf_util.ERF_A3};
  float a4 = ${erf_util.ERF_A4};
  float a5 = ${erf_util.ERF_A5};

  // The approximation holds only for x >= 0; erf is odd, so evaluate on |x|
  // and reapply the sign.
  float s = sign(x);
  float ax = abs(x);

  float t = 1.0 / (1.0 + p * ax);
  return s * (1.0 - (((((a5*t + a4)*t) + a3)*t + a2)*t + a1)*t*exp(-ax*ax));
`;

export const SQUARE = `return x * x;`;

export const RECIPROCAL = `return 1.0 / x;`;

/** 1.0 unless x >= 1.0, in which case 0.0. */
export const LOGICAL_NOT = `return float(!(x >= 1.0));`;

/** Drops the fractional part via GLSL int() conversion. */
export const TO_INT = `return float(int(x));`;

export const CLONE = 'return x;';