Chapter 14: torch.fft
“When you leave time behind and think in frequencies.”
14.1 What is torch.fft?
The torch.fft module brings PyTorch into the frequency domain, letting you:
- Decompose signals (like audio, images) into sine/cosine waves
- Denoise data
- Detect periodic patterns
- Power audio processing, computer vision, and even quantum simulations
It’s the PyTorch equivalent of NumPy’s np.fft and is built on highly optimized backend libraries (such as MKL on CPU and cuFFT on GPU).
14.2 Forward and Inverse FFT
The most basic use case: go to frequency space, and come back.
➤ 1D FFT and IFFT
import torch
x = torch.randn(8)
X = torch.fft.fft(x) # Frequency domain (complex tensor)
x_reconstructed = torch.fft.ifft(X) # Back to time domain (still complex; imaginary part is ~0)
The real and imaginary parts are available as X.real and X.imag.
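To check the round trip described above, here is a minimal sketch. The seed and the tolerance are illustrative choices, not values from the text.

```python
import torch

torch.manual_seed(0)  # illustrative seed so the run is repeatable
x = torch.randn(8)

X = torch.fft.fft(x)                  # complex spectrum of a real float32 signal
x_reconstructed = torch.fft.ifft(X)   # inverse transform, still a complex tensor

# The inverse should match the input up to floating-point error,
# with an imaginary part that is numerically close to zero.
roundtrip_ok = torch.allclose(x_reconstructed.real, x, atol=1e-6)
max_imag = x_reconstructed.imag.abs().max().item()

print(X.dtype, x_reconstructed.dtype, roundtrip_ok, max_imag)
```

Comparing with torch.allclose rather than exact equality is the safer habit here, since the reconstruction is only accurate to within rounding error.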
14.3 Real FFTs: rfft() and irfft()
If your input is real-valued (like audio), use real FFTs for speed:
x = torch.randn(8)
X = torch.fft.rfft(x) # Faster, optimized for real input
x_rec = torch.fft.irfft(X, n=8) # Reconstruct original signal
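- rfft() reduces output size by ~50%: an input of length n yields n // 2 + 1 complex values, because the remaining bins are redundant for real input
- irfft() needs n (the original signal length); without it, the output length defaults to 2 * (m - 1) for m frequency bins, which is wrong for odd-length signals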
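The size and length rules above can be checked directly. This is a small sketch; the lengths 8 and 9 are arbitrary examples chosen to show the even and odd cases.

```python
import torch

x_even = torch.randn(8)
x_odd = torch.randn(9)

X_even = torch.fft.rfft(x_even)  # 8 // 2 + 1 = 5 frequency bins
X_odd = torch.fft.rfft(x_odd)    # 9 // 2 + 1 = 5 frequency bins

# rfft returns the non-redundant half of the full FFT
matches_full_fft = torch.allclose(X_even, torch.fft.fft(x_even)[:5], atol=1e-5)

# Without n, irfft assumes an even length of 2 * (5 - 1) = 8
default_len = torch.fft.irfft(X_odd).shape[0]

# Passing n=9 recovers the odd-length signal
x_odd_rec = torch.fft.irfft(X_odd, n=9)

print(X_even.shape, X_odd.shape, default_len, x_odd_rec.shape)
```

Both inputs produce five bins, which is why irfft() cannot infer the original length on its own.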
14.4 2D FFTs: Hello, Images
Use fft2() and ifft2() to process 2D signals (images, heatmaps, etc.)
img = torch.randn(128, 128)
F_img = torch.fft.fft2(img)
img_back = torch.fft.ifft2(F_img).real # Often drop imaginary part
You can even apply frequency masks (e.g., blur, sharpen, edge-detect) directly in frequency space.
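As one possible illustration of a frequency mask, the sketch below applies a circular low-pass (blur) mask. The radius of 16 and the random "image" are assumptions made for the example, and fftshift/ifftshift are used only to put the zero frequency at the center so the mask is easy to build.

```python
import torch

torch.manual_seed(0)
img = torch.randn(128, 128)  # stand-in for a real grayscale image

# Shift the zero-frequency component to the center of the spectrum
F_img = torch.fft.fftshift(torch.fft.fft2(img))

# Build a circular low-pass mask around the center (radius is a hypothetical choice)
h, w = img.shape
yy = (torch.arange(h) - h // 2).view(-1, 1)
xx = (torch.arange(w) - w // 2).view(1, -1)
radius = 16
mask = ((yy ** 2 + xx ** 2) <= radius ** 2).float()

# Apply the mask, undo the shift, and return to the spatial domain
blurred = torch.fft.ifft2(torch.fft.ifftshift(F_img * mask)).real

print(img.var().item(), blurred.var().item())
```

Inverting the mask (1 - mask) would keep only high frequencies instead, which emphasizes edges rather than smoothing them.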
14.5 Common Functions in torch.fft
Function | Description |
---|---|
fft() | N-point FFT |
ifft() | Inverse FFT |
rfft() | Real-input FFT |
irfft() | Inverse of real FFT |
fft2(), ifft2() | 2D FFT and inverse |
fftn() | N-dimensional FFT |
14.6 Use Cases of FFT in Deep Learning
Application | Why FFT? |
---|---|
Audio analysis | Detect pitch, noise, rhythm |
Image filtering | Frequency-based blurs or edges |
Signal denoising | Filter out high-frequency noise |
Physics/finance models | Time-to-frequency domain switching |
Neural net acceleration | Multiply in freq space (FFT Conv) |
Spectral ConvNets? Yep — they multiply weights in the frequency domain.
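The idea behind "multiply in frequency space" is the convolution theorem. Here is a minimal 1D sketch, assuming a random signal and kernel, that compares an FFT-based convolution with a direct one computed by torch.nn.functional.conv1d. It is not how any particular spectral network is implemented.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
signal = torch.randn(64)
kernel = torch.randn(9)

# Zero-pad both inputs to the full linear-convolution length (64 + 9 - 1 = 72)
n = signal.shape[0] + kernel.shape[0] - 1
fft_conv = torch.fft.irfft(
    torch.fft.rfft(signal, n=n) * torch.fft.rfft(kernel, n=n),
    n=n,
)

# Reference: conv1d computes cross-correlation, so flip the kernel
# to obtain a true convolution, and pad to get the full output.
direct = F.conv1d(
    signal.view(1, 1, -1),
    kernel.flip(0).view(1, 1, -1),
    padding=kernel.shape[0] - 1,
).view(-1)

print(torch.allclose(fft_conv, direct, atol=1e-4))
```

The zero-padding matters: without it, the product of spectra gives a circular convolution that wraps around the ends of the signal.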
⚠️ 14.7 Caveats and Complex Tensor Handling
- Most FFT results are complex tensors
x = torch.fft.fft(signal)
magnitude = x.abs()
phase = x.angle()
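- ifft() returns you to the original domain, but the result is a complex tensor and may differ slightly from the input due to floating-point precision
- rfft() and irfft() require careful dimension tracking: rfft() shortens the transformed dimension to n // 2 + 1, so pass n to irfft() to restore the original length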
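To make the magnitude/phase handling concrete, this sketch splits a spectrum into magnitude and phase and rebuilds it with torch.polar. The random input and the tolerance are assumptions for illustration.

```python
import torch

torch.manual_seed(0)
signal = torch.randn(16)

x = torch.fft.fft(signal)
magnitude = x.abs()    # real tensor, |x|
phase = x.angle()      # real tensor, angle in radians

# Rebuild the complex spectrum from polar form: magnitude * exp(i * phase)
x_rebuilt = torch.polar(magnitude, phase)
signal_back = torch.fft.ifft(x_rebuilt).real

# A complex tensor can also be viewed as a real tensor with a trailing (real, imag) axis
as_real = torch.view_as_real(x)

print(torch.allclose(signal_back, signal, atol=1e-5), as_real.shape)
```

Working in polar form is handy when a filter should change magnitudes while leaving phases untouched.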
14.8 Example: Denoising a Signal with FFT
import torch
# Create noisy sine wave
t = torch.linspace(0, 1, 500)
signal = torch.sin(2 * torch.pi * 5 * t) + 0.5 * torch.randn_like(t)
# FFT
F_signal = torch.fft.fft(signal)
# Zero out high frequencies
F_filtered = F_signal.clone()
F_filtered[50:-49] = 0 # Keep bins 0..49 and their mirrored negative-frequency bins
# IFFT to recover signal
denoised = torch.fft.ifft(F_filtered).real
You just built a basic low-pass filter using PyTorch. 😎
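A variant of the same filter, sketched with rfft() and a cutoff expressed in Hz. The 500 Hz sample rate, the 20 Hz cutoff, and the seed are assumptions for the example. Keeping a clean copy of the sine wave makes it possible to measure how much the filter helps.

```python
import torch

torch.manual_seed(0)
n = 500
sample_rate = 500.0                      # Hz, assumed for this example
t = torch.arange(n) / sample_rate        # one second of samples

clean = torch.sin(2 * torch.pi * 5 * t)  # 5 Hz sine wave
noisy = clean + 0.5 * torch.randn_like(t)

# Real-input FFT and the frequency (in Hz) of each bin
spectrum = torch.fft.rfft(noisy)
freqs = torch.fft.rfftfreq(n, d=1.0 / sample_rate)

# The strongest bin should sit at the sine wave's frequency
dominant_hz = freqs[spectrum.abs().argmax()].item()

# Low-pass: zero every bin above the (hypothetical) cutoff
cutoff_hz = 20.0
spectrum[freqs > cutoff_hz] = 0
denoised = torch.fft.irfft(spectrum, n=n)

mse_noisy = torch.mean((noisy - clean) ** 2).item()
mse_denoised = torch.mean((denoised - clean) ** 2).item()
print(dominant_hz, mse_noisy, mse_denoised)
```

With rfft() there are no negative-frequency bins to keep in sync, so the mask is a single comparison against the frequency axis.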
✅ 14.9 Summary
Task | Use This |
---|---|
1D signal analysis | fft, rfft, ifft |
Image processing | fft2, ifft2 |
Speed + real input | rfft, irfft |
Custom filters | Modify FFT result, then ifft |
Neural speedups | Spectral convolutions |
- torch.fft brings NumPy-level spectral power into the PyTorch ecosystem
- Forward FFT outputs are complex tensors; handle real/imag wisely
- Use this for audio, images, denoising, and modeling periodic signals