UNPKG

@tensorflow/tfjs-core

Version:

Hardware-accelerated JavaScript library for machine intelligence

144 lines (125 loc) 4.17 kB
/**
 * @license
 * Copyright 2017 Google Inc. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */

import * as seedrandom from 'seedrandom';

/** A generator of random numbers; each call produces the next sample. */
export interface RandomBase {
  nextValue(): number;
}

/** Maps each supported dtype name to its typed-array type. */
export interface RandNormalDataTypes {
  float32: Float32Array;
  int32: Int32Array;
}

/**
 * Gaussian sampler built on the Marsaglia polar method.
 * See https://en.wikipedia.org/wiki/Marsaglia_polar_method
 *
 * Each round of the method yields two samples; the second one is cached and
 * handed out by the following `nextValue()` call.
 */
export class MPRandGauss implements RandomBase {
  private mean: number;
  private stdDev: number;
  private nextVal: number;
  private dtype?: keyof RandNormalDataTypes;
  private truncated?: boolean;
  private upper?: number;
  private lower?: number;
  private random: seedrandom.prng;

  /**
   * @param mean Mean of the distribution.
   * @param stdDeviation Standard deviation of the distribution.
   * @param dtype When set to anything other than 'float32', samples are
   *     rounded with `Math.round`.
   * @param truncated When true, samples further than two standard deviations
   *     from the mean are rejected.
   * @param seed Seed for the underlying generator. When omitted, a seed is
   *     drawn from `Math.random()`.
   */
  constructor(
      mean: number, stdDeviation: number, dtype?: keyof RandNormalDataTypes,
      truncated?: boolean, seed?: number) {
    this.mean = mean;
    this.stdDev = stdDeviation;
    this.dtype = dtype;
    this.nextVal = NaN;
    this.truncated = truncated;
    if (this.truncated) {
      // Use the magnitude of stdDev so that lower <= upper always holds; a
      // negative stdDev would otherwise produce an empty acceptance window
      // and the rejection loop in nextValue() would never terminate.
      const halfWidth = 2 * Math.abs(this.stdDev);
      this.upper = this.mean + halfWidth;
      this.lower = this.mean - halfWidth;
    }
    // Compare against null/undefined instead of testing truthiness so that an
    // explicit seed of 0 is honoured (same rule as UniformRandom).
    const seedValue = seed == null ? Math.random() : seed;
    this.random = seedrandom.alea(seedValue.toString());
  }

  /** Returns next sample from a gaussian distribution. */
  public nextValue(): number {
    // Hand out the sample cached by the previous round, if there is one.
    if (!isNaN(this.nextVal)) {
      const value = this.nextVal;
      this.nextVal = NaN;
      return value;
    }

    let resultX: number;
    let resultY: number;
    do {
      let v1: number;
      let v2: number;
      let s: number;
      // Draw points in the square [-1, 1) x [-1, 1) until one falls strictly
      // inside the unit circle and is not the origin.
      do {
        v1 = 2 * this.random() - 1;
        v2 = 2 * this.random() - 1;
        s = v1 * v1 + v2 * v2;
      } while (s >= 1 || s === 0);

      const mul = Math.sqrt(-2.0 * Math.log(s) / s);
      resultX = this.mean + this.stdDev * v1 * mul;
      resultY = this.mean + this.stdDev * v2 * mul;
      // When truncating, redraw the whole pair until the first sample is
      // inside the acceptance window.
    } while (this.truncated && !this.isValidTruncated(resultX));

    // Only cache the second sample if it is acceptable as well.
    if (!this.truncated || this.isValidTruncated(resultY)) {
      this.nextVal = this.convertValue(resultY);
    }
    return this.convertValue(resultX);
  }

  /** Handles proper rounding for non floating point numbers. */
  private convertValue(value: number): number {
    if (this.dtype == null || this.dtype === 'float32') {
      return value;
    }
    return Math.round(value);
  }

  /**
   * Returns true if the value is within 2 standard deviations of the mean
   * (bounds inclusive). Returns false when the bounds were never set.
   */
  private isValidTruncated(value: number): boolean {
    const {lower, upper} = this;
    return lower != null && upper != null && value <= upper && value >= lower;
  }
}

/**
 * Sampler that returns `min + (max - min) * r`, where `r` is the next output
 * of a seeded generator. For non-float dtypes the result is rounded with
 * `Math.round`.
 */
export class UniformRandom implements RandomBase {
  private min: number;
  private range: number;
  private random: seedrandom.prng;
  private dtype?: keyof RandNormalDataTypes;

  /**
   * @param min Start of the sampled interval.
   * @param max End of the sampled interval.
   * @param dtype When set to anything other than 'float32', samples are
   *     rounded with `Math.round`.
   * @param seed Seed for the underlying generator. When omitted, a seed is
   *     drawn from `Math.random()`.
   * @throws Error if dtype is not float and `max - min <= 1`.
   */
  constructor(
      min = 0, max = 1, dtype?: keyof RandNormalDataTypes,
      seed?: string|number) {
    this.min = min;
    this.range = max - min;
    this.dtype = dtype;

    const rawSeed = seed == null ? Math.random() : seed;
    const seedText = typeof rawSeed === 'number' ? rawSeed.toString() : rawSeed;

    if (!this.canReturnFloat() && this.range <= 1) {
      throw new Error(
          `The difference between ${min} - ${max} <= 1 and dtype is not float`);
    }
    this.random = seedrandom.alea(seedText);
  }

  /** Returns true when samples may be returned without rounding. */
  private canReturnFloat(): boolean {
    return this.dtype == null || this.dtype === 'float32';
  }

  /** Handles proper rounding for non floating point numbers. */
  private convertValue(value: number): number {
    if (this.canReturnFloat()) {
      return value;
    }
    return Math.round(value);
  }

  /** Returns the next sample from the uniform distribution. */
  nextValue(): number {
    return this.convertValue(this.min + this.range * this.random());
  }
}