UNPKG

@tensorflow/tfjs-core

Version:

Hardware-accelerated JavaScript library for machine intelligence

210 lines (187 loc) 7.03 kB
/**
 * @license
 * Copyright 2018 Google Inc. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */

import {ENGINE} from '../engine';
import {complex, imag, real} from '../ops/complex_ops';
import {op} from '../ops/operation';
import {Tensor, Tensor2D} from '../tensor';
import {assert} from '../util';
import {scalar, zeros} from './tensor_ops';

/**
 * Fast Fourier transform.
 *
 * Computes the 1-dimensional discrete Fourier transform over the inner-most
 * dimension of input. All outer dimensions are treated as batch dimensions
 * and the result has the same shape as `input`.
 *
 * ```js
 * const real = tf.tensor1d([1, 2, 3]);
 * const imag = tf.tensor1d([1, 2, 3]);
 * const x = tf.complex(real, imag);
 *
 * x.fft().print();  // tf.spectral.fft(x).print();
 * ```
 * @param input The complex input to compute an fft over. Must have dtype
 *     'complex64'.
 */
/**
 * @doc {heading: 'Operations', subheading: 'Spectral', namespace: 'spectral'}
 */
function fft_(input: Tensor): Tensor {
  assert(
      input.dtype === 'complex64',
      () => `The dtype for tf.spectral.fft() must be complex64 ` +
          `but got ${input.dtype}.`);

  // Collapse all outer dimensions to a single batch dimension; the backend
  // kernel operates on a 2D [batch, innerDimensionSize] tensor.
  const innerDimensionSize = input.shape[input.shape.length - 1];
  const batch = input.size / innerDimensionSize;
  const input2D = input.as2D(batch, innerDimensionSize);

  const ret = ENGINE.runKernel(backend => backend.fft(input2D), {input});
  // Restore the caller's original (possibly higher-rank) shape.
  return ret.reshape(input.shape);
}

/**
 * Inverse fast Fourier transform.
 *
 * Computes the inverse 1-dimensional discrete Fourier transform over the
 * inner-most dimension of input. All outer dimensions are treated as batch
 * dimensions and the result has the same shape as `input`.
 *
 * ```js
 * const real = tf.tensor1d([1, 2, 3]);
 * const imag = tf.tensor1d([1, 2, 3]);
 * const x = tf.complex(real, imag);
 *
 * x.ifft().print();  // tf.spectral.ifft(x).print();
 * ```
 * @param input The complex input to compute an ifft over. Must have dtype
 *     'complex64'.
 */
/**
 * @doc {heading: 'Operations', subheading: 'Spectral', namespace: 'spectral'}
 */
function ifft_(input: Tensor): Tensor {
  assert(
      input.dtype === 'complex64',
      () => `The dtype for tf.spectral.ifft() must be complex64 ` +
          `but got ${input.dtype}.`);

  // Collapse all outer dimensions to a single batch dimension; the backend
  // kernel operates on a 2D [batch, innerDimensionSize] tensor.
  const innerDimensionSize = input.shape[input.shape.length - 1];
  const batch = input.size / innerDimensionSize;
  const input2D = input.as2D(batch, innerDimensionSize);

  const ret = ENGINE.runKernel(backend => backend.ifft(input2D), {input});
  // Restore the caller's original (possibly higher-rank) shape.
  return ret.reshape(input.shape);
}

/**
 * Real value input fast Fourier transform.
 *
 * Computes the 1-dimensional discrete Fourier transform over the
 * inner-most dimension of the real input.
 *
 * The DFT of a real signal is conjugate-symmetric, so only the first
 * `floor(n / 2) + 1` frequency bins are returned, where `n` is the inner-most
 * dimension size after any cropping/padding requested via `fftLength`. The
 * result is complex64 and keeps all outer dimensions of `input`.
 *
 * ```js
 * const real = tf.tensor1d([1, 2, 3]);
 *
 * real.rfft().print();
 * ```
 * @param input The real value input to compute an rfft over. Must have dtype
 *     'float32'.
 * @param fftLength Optional transform length. If smaller than the inner-most
 *     dimension the input is cropped to it; if larger the input is padded with
 *     zeros up to it. Must be a positive integer when provided.
 */
/**
 * @doc {heading: 'Operations', subheading: 'Spectral', namespace: 'spectral'}
 */
function rfft_(input: Tensor, fftLength?: number): Tensor {
  assert(
      input.dtype === 'float32',
      () => `The dtype for rfft() must be real value but got ${input.dtype}`);
  // Reject lengths that would otherwise fail deep inside slice/split/zeros
  // with an unrelated error message.
  assert(
      fftLength == null ||
          (fftLength > 0 && Math.floor(fftLength) === fftLength),
      () => `The fftLength for rfft() must be a positive integer ` +
          `but got ${fftLength}.`);

  const lastAxis = input.shape.length - 1;
  let innerDimensionSize = input.shape[lastAxis];
  // Only the inner-most dimension is cropped or padded below, so the batch
  // size is the same before and after adjusting to fftLength.
  const batch = input.size / innerDimensionSize;

  let adjustedInput: Tensor;
  if (fftLength != null && fftLength < innerDimensionSize) {
    // Need to crop the inner-most dimension down to fftLength.
    const begin = input.shape.map(() => 0);
    const size = input.shape.slice();
    size[lastAxis] = fftLength;
    adjustedInput = input.slice(begin, size);
    innerDimensionSize = fftLength;
  } else if (fftLength != null && fftLength > innerDimensionSize) {
    // Need to pad the inner-most dimension with zeros up to fftLength.
    const zerosShape = input.shape.slice();
    zerosShape[lastAxis] = fftLength - innerDimensionSize;
    adjustedInput = input.concat(zeros(zerosShape), lastAxis);
    innerDimensionSize = fftLength;
  } else {
    adjustedInput = input;
  }

  // Complement the input with zero imaginary numbers.
  const zerosInput = adjustedInput.zerosLike();
  const complexInput =
      complex(adjustedInput, zerosInput).as2D(batch, innerDimensionSize);

  const ret = fft(complexInput);

  // For real input the spectrum is conjugate-symmetric: every bin past
  // `half` is the conjugate of an earlier one, so only the first `half` bins
  // are kept and the redundant tail is discarded.
  const half = Math.floor(innerDimensionSize / 2) + 1;
  const realValues = real(ret);
  const imagValues = imag(ret);
  const realComplexConjugate = realValues.split(
      [half, innerDimensionSize - half], realValues.shape.length - 1);
  const imagComplexConjugate = imagValues.split(
      [half, innerDimensionSize - half], imagValues.shape.length - 1);

  // Keep the caller's outer dimensions; only the inner-most one shrinks.
  const outputShape = adjustedInput.shape.slice();
  outputShape[lastAxis] = half;

  return complex(realComplexConjugate[0], imagComplexConjugate[0])
      .reshape(outputShape);
}

/**
 * Inverse real value input fast Fourier transform.
 *
 * Computes the 1-dimensional inverse discrete Fourier transform over the
 * inner-most dimension of the complex input and returns the real part.
 *
 * For an inner-most dimension size `n > 2` the input is treated as the first
 * `n` bins of a conjugate-symmetric spectrum. The full spectrum of length
 * `2 * (n - 1)` is rebuilt by appending the conjugates of bins `1..n-2` in
 * reverse order. For `n <= 2` the input is inverse-transformed as-is.
 *
 * Note: all outer dimensions are collapsed into a single batch dimension and
 * are not restored. The result is always 2D, `[batch, outLength]`, where
 * `outLength` is `2 * (n - 1)` when `n > 2` and `n` otherwise.
 *
 * ```js
 * const real = tf.tensor1d([1, 2, 3]);
 * const imag = tf.tensor1d([0, 0, 0]);
 * const x = tf.complex(real, imag);
 *
 * x.irfft().print();
 * ```
 * @param input The complex input to compute an irfft over. Must have dtype
 *     'complex64'.
 */
/**
 * @doc {heading: 'Operations', subheading: 'Spectral', namespace: 'spectral'}
 */
function irfft_(input: Tensor): Tensor {
  // Same contract as fft_/ifft_; without this, non-complex inputs reach
  // real()/imag() below with no dtype check.
  assert(
      input.dtype === 'complex64',
      () => `The dtype for tf.spectral.irfft() must be complex64 ` +
          `but got ${input.dtype}.`);

  const innerDimensionSize = input.shape[input.shape.length - 1];
  const batch = input.size / innerDimensionSize;

  if (innerDimensionSize <= 2) {
    // With at most two bins there are no interior bins whose conjugates need
    // to be appended, so the input is inverse-transformed directly.
    const complexInput = input.as2D(batch, innerDimensionSize);
    const ret = ifft(complexInput);
    return real(ret);
  }

  // The length of unique components of the DFT of a real-valued signal
  // is 2 * (input_len - 1).
  const outputLength = 2 * (innerDimensionSize - 1);
  const realInput = real(input).as2D(batch, innerDimensionSize);
  const imagInput = imag(input).as2D(batch, innerDimensionSize);

  // Mirror the interior bins 1..n-2; conjugation negates the imaginary part.
  const realConjugate =
      realInput.slice([0, 1], [batch, innerDimensionSize - 2]).reverse(1);
  const imagConjugate =
      imagInput.slice([0, 1], [batch, innerDimensionSize - 2])
          .reverse(1)
          .mul(scalar(-1)) as Tensor2D;

  const r = realInput.concat(realConjugate, 1);
  const i = imagInput.concat(imagConjugate, 1);
  const complexInput = complex(r, i).as2D(batch, outputLength);

  const ret = ifft(complexInput);
  return real(ret);
}

export const fft = op({fft_});
export const ifft = op({ifft_});
export const rfft = op({rfft_});
export const irfft = op({irfft_});