UNPKG

@tensorflow/tfjs-core

Version:

Hardware-accelerated JavaScript library for machine intelligence

74 lines (60 loc) 2.32 kB
/**
 * @license
 * Copyright 2018 Google Inc. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */

import {GPGPUProgram} from './gpgpu_math';

/**
 * GLSL bodies for the `unaryOpComplex(real, expR, imag, expI)` function that
 * {@link FFTProgram} splices into its shader. Each returns one component of
 * the complex product (real + i*imag) * (expR + i*expI):
 * - `REAL`: the real part, `real*expR - imag*expI`.
 * - `IMAG`: the imaginary part, `real*expI + imag*expR`.
 */
export const COMPLEX_FFT = {
  REAL: 'return real * expR - imag * expI;',
  IMAG: 'return real * expI + imag * expR;'
};

/**
 * GPGPU program that computes ONE component (real or imaginary, selected by
 * the `op` snippet) of a direct discrete Fourier transform along the inner
 * dimension of a 2-D input. Each output element sums over all `innerDim`
 * inputs of its row, so the work is O(rows * innerDim^2); no FFT butterfly
 * factorization is used.
 *
 * The output shape equals the input shape. In the shader, `coords[0]` selects
 * the row and `coords[1]` the frequency (or, for the inverse, sample) index.
 * NOTE(review): `getReal`/`getImag` are not defined here; they are presumably
 * generated from `variableNames` by the shader compiler — confirm.
 */
export class FFTProgram implements GPGPUProgram {
  variableNames = ['real', 'imag'];
  outputShape: number[];
  userCode: string;

  /**
   * @param op GLSL body for `unaryOpComplex(real, expR, imag, expI)`, which
   *     must `return` a float; normally `COMPLEX_FFT.REAL` or
   *     `COMPLEX_FFT.IMAG`.
   * @param inputShape `[rows, innerDim]`; the transform runs along `innerDim`.
   * @param inverse When true, uses exponent `+2*PI*k*n/N` and divides the sum
   *     by `innerDim` (inverse DFT); otherwise uses `-2*PI*k*n/N` with no
   *     scaling (forward DFT).
   */
  constructor(op: string, inputShape: [number, number], inverse: boolean) {
    const innerDim = inputShape[1];
    // Copy so later mutation of the caller's array cannot change our shape.
    this.outputShape = inputShape.slice();

    const exponentMultiplierSnippet =
        inverse ? `2.0 * ${Math.PI}` : `-2.0 * ${Math.PI}`;
    const resultDenominator = inverse ? `${innerDim}.0` : '1.0';

    this.userCode = `
      const float exponentMultiplier = ${exponentMultiplierSnippet};

      float unaryOpComplex(float real, float expR, float imag, float expI) {
        ${op}
      }

      float mulMatDFT(int batch, int index) {
        float result = 0.0;

        for (int i = 0; i < ${innerDim}; i++) {
          // Angle x = (-2|2 * PI / N) * index * i. It is periodic in
          // index * i with period N, so reduce that product mod N first:
          // cos/sin then see an argument in [0, 2 * PI] instead of one that
          // grows to ~2 * PI * N and loses float precision for large N.
          // Any rounding slop in mod() only shifts x by a multiple of 2 * PI.
          float phase = mod(float(index) * float(i), float(${innerDim}));
          float x = exponentMultiplier * (phase / float(${innerDim}));
          float expR = cos(x);
          float expI = sin(x);
          float real = getReal(batch, i);
          float imag = getImag(batch, i);

          result += unaryOpComplex(real, expR, imag, expI);
        }

        // Normalize once (1.0 forward, N inverse) instead of per term.
        return result / ${resultDenominator};
      }

      void main() {
        ivec2 coords = getOutputCoords();
        setOutput(mulMatDFT(coords[0], coords[1]));
      }
    `;
  }
}