Fourier Transforms
spectrans.transforms.fourier
Fourier transform implementations for spectral neural networks.
This module provides Fourier transform implementations for spectral transformer architectures. The transforms are built on PyTorch's native FFT operations for GPU acceleration and automatic differentiation support.
With 'ortho' normalization, the complex transforms (FFT1D, FFT2D) are unitary: they preserve complex inner products and therefore signal energy (Parseval's theorem). All transforms accept a normalization mode; FFT1D and FFT2D handle both real and complex inputs, while RFFT and RFFT2D are specialized for real-valued inputs.
Classes:
| Name | Description |
|---|---|
| FFT1D | 1D Fast Fourier Transform with configurable normalization. |
| FFT2D | 2D Fast Fourier Transform for AFNO-style 2D operations. |
| RFFT | Real-input Fast Fourier Transform. |
| RFFT2D | 2D Real-input Fast Fourier Transform. |
| SpectralPooling | Spectral domain pooling operation via frequency truncation. |
Examples:
Basic 1D FFT usage:
>>> import torch
>>> from spectrans.transforms.fourier import FFT1D
>>> fft = FFT1D(norm='ortho')
>>> signal = torch.randn(32, 512, dtype=torch.complex64)
>>> freq_domain = fft.transform(signal, dim=-1)
>>> reconstructed = fft.inverse_transform(freq_domain, dim=-1)
Real-input FFT:
>>> from spectrans.transforms.fourier import RFFT
>>> rfft = RFFT(norm='ortho')
>>> real_signal = torch.randn(32, 512)
>>> freq_domain = rfft.transform(real_signal) # Returns complex output
>>> # Note: inverse returns real values for real-input FFTs
2D FFT for AFNO operations:
>>> from spectrans.transforms.fourier import FFT2D
>>> fft2d = FFT2D(norm='ortho')
>>> tensor_2d = torch.randn(32, 64, 64, dtype=torch.complex64)
>>> freq_2d = fft2d.transform(tensor_2d, dim=(-2, -1))
Spectral pooling for downsampling:
>>> from spectrans.transforms.fourier import SpectralPooling
>>> pool = SpectralPooling(output_size=256)
>>> downsampled = pool.transform(real_signal, dim=-1)  # pool the length-512 axis down to 256
Notes
Mathematical Properties:
Fourier transforms with 'ortho' normalization are unitary, which gives:
- Energy conservation: \(\|\mathcal{F}(\mathbf{x})\|^2 = \|\mathbf{x}\|^2\) (the special case \(\mathbf{y} = \mathbf{x}\) of Parseval's theorem)
- Parseval's theorem: \(\langle \mathbf{x}, \mathbf{y} \rangle = \langle \mathcal{F}(\mathbf{x}), \mathcal{F}(\mathbf{y}) \rangle\), where \(\langle \mathbf{a}, \mathbf{b} \rangle = \sum_k a_k \overline{b_k}\)
- Perfect reconstruction: \(\mathcal{F}^{-1}(\mathcal{F}(\mathbf{x})) = \mathbf{x}\)
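These properties can be checked numerically. A minimal sketch that calls PyTorch's functional `torch.fft` API directly rather than the spectrans classes:

```python
import torch

torch.manual_seed(0)
n = 256
x = torch.randn(n, dtype=torch.complex128)
y = torch.randn(n, dtype=torch.complex128)

X = torch.fft.fft(x, norm="ortho")
Y = torch.fft.fft(y, norm="ortho")

# Energy conservation: ||F(x)||^2 == ||x||^2
energy_time = torch.sum(x.abs() ** 2)
energy_freq = torch.sum(X.abs() ** 2)

# Parseval: <x, y> == <F(x), F(y)>.
# torch.vdot(a, b) = sum(conj(a) * b), so vdot(y, x) = sum_k x_k * conj(y_k) = <x, y>
inner_time = torch.vdot(y, x)
inner_freq = torch.vdot(Y, X)

# Perfect reconstruction with the matching inverse
x_rec = torch.fft.ifft(X, norm="ortho")

print(torch.allclose(energy_time, energy_freq))  # True
print(torch.allclose(inner_time, inner_freq))    # True
print(torch.allclose(x_rec, x))                  # True
```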
Normalization Modes:
- 'forward': \(\frac{1}{n}\) scaling on the forward transform, no scaling on the inverse
- 'backward': No scaling on the forward transform, \(\frac{1}{n}\) scaling on the inverse (PyTorch's default)
- 'ortho': \(\frac{1}{\sqrt{n}}\) scaling in both directions (unitary)
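The mode names match PyTorch's `norm` argument; assuming spectrans passes the mode through to `torch.fft` unchanged, the scaling can be seen on a constant signal, whose only nonzero coefficient is the DC bin:

```python
import torch

n = 8
x = torch.ones(n, dtype=torch.complex64)  # constant signal: all energy in the DC bin

dc = {}
for mode in ("backward", "forward", "ortho"):
    X = torch.fft.fft(x, norm=mode)
    dc[mode] = X[0].real.item()  # DC coefficient = (sum of x) * forward scale
    # Using the same mode for the inverse always round-trips
    assert torch.allclose(torch.fft.ifft(X, norm=mode), x)

# backward: 8.0 (no scaling), forward: 1.0 (1/n), ortho: about 2.83 (1/sqrt(n))
print(dc)
```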
The 'ortho' mode is recommended for neural networks because it leaves the norm of activations unchanged through both the forward and inverse transforms, so scaling stays consistent throughout the network regardless of sequence length.
Real-Input FFT: RFFT and RFFT2D exploit the Hermitian symmetry of the FFT of a real signal, \(X_{n-k} = \overline{X_k}\), storing only the \(\lfloor n/2 \rfloor + 1\) non-redundant frequency components along the (last) transformed dimension.
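A short check of the Hermitian symmetry argument, using `torch.fft.fft` and `torch.fft.rfft` directly (a sketch, not the spectrans code path):

```python
import torch

torch.manual_seed(0)
n = 10
x = torch.randn(n, dtype=torch.float64)  # real-valued input

full = torch.fft.fft(x, norm="ortho")   # all n complex bins
half = torch.fft.rfft(x, norm="ortho")  # only n // 2 + 1 complex bins
print(full.shape, half.shape)  # torch.Size([10]) torch.Size([6])

# rfft keeps the first n // 2 + 1 bins of the full spectrum ...
print(torch.allclose(half, full[: n // 2 + 1]))  # True

# ... and the dropped bins follow from Hermitian symmetry: X[n - k] = conj(X[k])
k = torch.arange(1, n - n // 2)  # k = 1, ..., 4 for n = 10
print(torch.allclose(full[n - k], half[k].conj()))  # True
```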
GPU Acceleration: All transforms utilize PyTorch's cuFFT backend when tensors are on GPU.
Gradient Support: All transforms support automatic differentiation through PyTorch's autograd system, enabling end-to-end training of spectral neural networks.
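A toy example of end-to-end differentiation through `torch.fft.rfft` and `torch.fft.irfft`. The learnable per-frequency `weight` is a hypothetical global-filter-style mixer chosen for illustration, not a spectrans layer:

```python
import torch

torch.manual_seed(0)
seq_len = 16
x = torch.randn(4, seq_len, requires_grad=True)

# Illustrative learnable filter: one real weight per retained frequency bin
weight = torch.randn(seq_len // 2 + 1, requires_grad=True)

X = torch.fft.rfft(x, dim=-1, norm="ortho")                       # (4, 9), complex
y = torch.fft.irfft(X * weight, n=seq_len, dim=-1, norm="ortho")  # (4, 16), real

loss = y.pow(2).sum()
loss.backward()  # gradients flow back through irfft and rfft

print(x.grad.shape, weight.grad.shape)  # torch.Size([4, 16]) torch.Size([9])
```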
See Also
- spectrans.transforms.base : Base classes for transform interfaces
- spectrans.utils.complex : Complex tensor utility functions
- spectrans.layers.mixing.fourier : Neural layers using these transforms
Classes
FFT1D
Bases: UnitaryTransform
1D Fast Fourier Transform.
Applies 1D FFT along a specified dimension of the input tensor.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| norm | FFTNorm | Normalization mode: "forward", "backward", or "ortho". | "ortho" |
Methods:
| Name | Description |
|---|---|
| transform | Apply 1D FFT. |
| inverse_transform | Apply inverse 1D FFT. |
Source code in spectrans/transforms/fourier.py
Functions
transform
Apply 1D FFT.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Input tensor of real or complex values. | required |
| dim | int | Dimension along which to apply FFT. | -1 |
Returns:
| Type | Description |
|---|---|
| ComplexTensor | Complex-valued FFT result. |
inverse_transform
Apply inverse 1D FFT.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | ComplexTensor | Complex-valued FFT coefficients. | required |
| dim | int | Dimension along which to apply inverse FFT. | -1 |
Returns:
| Type | Description |
|---|---|
| Tensor | Inverse FFT result (may be complex if input was complex). |
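As an illustration only, the FFT1D interface can be mirrored by a thin wrapper over `torch.fft`. `MinimalFFT1D` and `FFTNormMode` are hypothetical names; the sketch assumes FFT1D delegates to `torch.fft.fft` / `torch.fft.ifft` with the configured `norm`, which this page does not confirm.

```python
from typing import Literal

import torch

# Hypothetical stand-in for spectrans' FFTNorm type
FFTNormMode = Literal["forward", "backward", "ortho"]


class MinimalFFT1D:
    """Hypothetical sketch of an FFT1D-like wrapper around torch.fft."""

    def __init__(self, norm: FFTNormMode = "ortho") -> None:
        if norm not in ("forward", "backward", "ortho"):
            raise ValueError(f"unsupported norm: {norm!r}")
        self.norm = norm

    def transform(self, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
        # torch.fft.fft promotes real inputs to complex automatically
        return torch.fft.fft(x, dim=dim, norm=self.norm)

    def inverse_transform(self, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
        # The same norm on both sides makes inverse(transform(x)) == x
        return torch.fft.ifft(x, dim=dim, norm=self.norm)


fft = MinimalFFT1D(norm="ortho")
signal = torch.randn(32, 512)
freq = fft.transform(signal, dim=-1)        # complex64, shape (32, 512)
recon = fft.inverse_transform(freq, dim=-1)  # complex; imaginary part is ~0 here
print(freq.dtype, torch.allclose(recon.real, signal, atol=1e-5))
```

The inverse of a complex FFT is complex in general; taking `.real` is only appropriate when the original signal was known to be real.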
FFT2D
Bases: SpectralTransform2D
2D Fast Fourier Transform.
Applies a 2D FFT over two dimensions of the input tensor (by default the last two).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| norm | FFTNorm | Normalization mode: "forward", "backward", or "ortho". | "ortho" |
Methods:
| Name | Description |
|---|---|
| transform | Apply 2D FFT. |
| inverse_transform | Apply inverse 2D FFT. |
Functions
transform
Apply 2D FFT.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Input tensor of real or complex values. | required |
| dim | tuple[int, int] | Dimensions along which to apply 2D FFT. | (-2, -1) |
Returns:
| Type | Description |
|---|---|
| ComplexTensor | Complex-valued 2D FFT result. |
inverse_transform
Apply inverse 2D FFT.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | ComplexTensor | Complex-valued FFT coefficients. | required |
| dim | tuple[int, int] | Dimensions along which to apply inverse FFT. | (-2, -1) |
Returns:
| Type | Description |
|---|---|
| Tensor | Inverse FFT result. |
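The 2D DFT is separable: a 2D FFT over `(-2, -1)` equals a 1D FFT along one dimension followed by a 1D FFT along the other, and with 'ortho' the per-axis factors \(1/\sqrt{H}\) and \(1/\sqrt{W}\) multiply to \(1/\sqrt{HW}\). The check below uses `torch.fft` directly and only assumes FFT2D computes the same quantity:

```python
import torch

torch.manual_seed(0)
x = torch.randn(8, 16, 32, dtype=torch.complex128)  # (batch, height, width)

X2 = torch.fft.fft2(x, dim=(-2, -1), norm="ortho")

# Separable form: 1D FFT along width, then 1D FFT along height
X_sep = torch.fft.fft(torch.fft.fft(x, dim=-1, norm="ortho"), dim=-2, norm="ortho")

print(torch.allclose(X2, X_sep))  # True
print(torch.allclose(torch.fft.ifft2(X2, dim=(-2, -1), norm="ortho"), x))  # True
```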
RFFT
Bases: UnitaryTransform
Real Fast Fourier Transform.
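Applies an FFT to real-valued inputs and returns only the non-negative frequency components (n // 2 + 1 bins for a length-n input); the negative frequencies are redundant by Hermitian symmetry.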
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| norm | FFTNorm | Normalization mode: "forward", "backward", or "ortho". | "ortho" |
Methods:
| Name | Description |
|---|---|
| transform | Apply real FFT. |
| inverse_transform | Apply inverse real FFT. |
Functions
transform
Apply real FFT.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Real-valued input tensor. | required |
| dim | int | Dimension along which to apply RFFT. | -1 |
Returns:
| Type | Description |
|---|---|
| ComplexTensor | Complex-valued RFFT result (non-negative frequencies only: n // 2 + 1 bins along dim for an input of length n). |
inverse_transform
Apply inverse real FFT.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | ComplexTensor | Complex-valued RFFT coefficients. | required |
| dim | int | Dimension along which to apply inverse RFFT. | -1 |
| n | int \| None | Length of the output signal. If None, inferred from input. | None |
Returns:
| Type | Description |
|---|---|
| Tensor | Real-valued inverse RFFT result. |
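The `n` argument matters for odd lengths: inputs of length 15 and 14 both produce 15 // 2 + 1 = 14 // 2 + 1 = 8 bins, so the original length cannot be recovered from the coefficients alone. The sketch below uses `torch.fft.irfft`, whose default output length is 2 * (m - 1) for m input bins; assuming `RFFT.inverse_transform` forwards `n` to it, "inferred from input" means that default.

```python
import torch

torch.manual_seed(0)
x = torch.randn(4, 15)  # odd length along the last dimension

X = torch.fft.rfft(x, dim=-1, norm="ortho")  # 15 // 2 + 1 = 8 bins
print(X.shape)  # torch.Size([4, 8])

# Without n, irfft assumes the even length 2 * (8 - 1) = 14 and loses a sample
print(torch.fft.irfft(X, dim=-1, norm="ortho").shape)  # torch.Size([4, 14])

# Passing the original length restores the signal up to floating-point error
x_rec = torch.fft.irfft(X, n=15, dim=-1, norm="ortho")
print(x_rec.shape, torch.allclose(x_rec, x, atol=1e-5))  # torch.Size([4, 15]) True
```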
RFFT2D
Bases: SpectralTransform2D
2D Real Fast Fourier Transform.
Applies a 2D FFT to real-valued inputs, returning only the non-redundant half of the spectrum along the last transformed dimension.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| norm | FFTNorm | Normalization mode: "forward", "backward", or "ortho". | "ortho" |
Methods:
| Name | Description |
|---|---|
| transform | Apply 2D real FFT. |
| inverse_transform | Apply inverse 2D real FFT. |
Functions
transform
Apply 2D real FFT.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Real-valued input tensor. | required |
| dim | tuple[int, int] | Dimensions along which to apply 2D RFFT. | (-2, -1) |
Returns:
| Type | Description |
|---|---|
| ComplexTensor | Complex-valued 2D RFFT result. |
inverse_transform
inverse_transform(x: ComplexTensor, dim: tuple[int, int] = (-2, -1), s: tuple[int, int] | None = None) -> Tensor
Apply inverse 2D real FFT.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | ComplexTensor | Complex-valued RFFT coefficients. | required |
| dim | tuple[int, int] | Dimensions along which to apply inverse RFFT. | (-2, -1) |
| s | tuple[int, int] \| None | Output signal size. If None, inferred from input. | None |
Returns:
| Type | Description |
|---|---|
| Tensor | Real-valued inverse RFFT result. |
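The same length ambiguity applies in 2D, which is what `s` resolves. This sketch uses `torch.fft.rfft2` / `torch.fft.irfft2` directly and assumes RFFT2D follows their convention of halving only the last transformed dimension:

```python
import torch

torch.manual_seed(0)
x = torch.randn(2, 33, 65)  # (batch, height, width) with odd spatial sizes

X = torch.fft.rfft2(x, dim=(-2, -1), norm="ortho")
print(X.shape)  # torch.Size([2, 33, 33]): only the width is halved (65 // 2 + 1)

# Without s, the width is taken as 2 * (33 - 1) = 64
print(torch.fft.irfft2(X, dim=(-2, -1), norm="ortho").shape)  # torch.Size([2, 33, 64])

# Passing the original spatial size restores the input
x_rec = torch.fft.irfft2(X, s=(33, 65), dim=(-2, -1), norm="ortho")
print(x_rec.shape, torch.allclose(x_rec, x, atol=1e-5))
```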
SpectralPooling
Bases: UnitaryTransform
Spectral pooling via frequency domain truncation.
Reduces spatial dimensions by truncating high-frequency components in the Fourier domain.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| output_size | int \| tuple[int, ...] | Target output size after pooling. | required |
| norm | FFTNorm | Normalization mode for FFT operations. | "ortho" |
Methods:
| Name | Description |
|---|---|
| transform | Apply spectral pooling. |
| inverse_transform | Inverse is not well-defined for pooling operations. |
Functions
transform
Apply spectral pooling.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Input tensor to pool. | required |
| dim | int \| tuple[int, ...] | Dimensions to pool along. | -1 |
Returns:
| Type | Description |
|---|---|
| Tensor | Spectrally pooled tensor. |
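A minimal 1D sketch of pooling by frequency truncation, assuming the input is real and the spectrum is taken with `torch.fft.rfft`. `spectral_pool_1d` is a hypothetical helper, not the SpectralPooling implementation: the amplitude rescaling and the handling of the bin at `output_size // 2` (which `irfft` treats as a Nyquist bin when `output_size` is even) are design choices the real class may make differently.

```python
import torch


def spectral_pool_1d(x: torch.Tensor, output_size: int, dim: int = -1) -> torch.Tensor:
    """Hypothetical spectral pooling: keep only the lowest frequencies of a real signal."""
    n = x.shape[dim]
    if output_size > n:
        raise ValueError("output_size must not exceed the input size")
    X = torch.fft.rfft(x, dim=dim, norm="ortho")
    # Keep the bins a length-output_size signal would have: 0 .. output_size // 2
    X_low = X.narrow(dim, 0, output_size // 2 + 1)
    y = torch.fft.irfft(X_low, n=output_size, dim=dim, norm="ortho")
    # 'ortho' scales by 1/sqrt(length) in each direction; rescale so the DC level is preserved
    return y * (output_size / n) ** 0.5


x = torch.randn(32, 512)
pooled = spectral_pool_1d(x, output_size=256)
print(pooled.shape)  # torch.Size([32, 256])
```

Because the high-frequency bins are discarded before the inverse, no inverse operation can restore them, which is why pooling has no well-defined inverse.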
inverse_transform
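Inverse is not well-defined for pooling operations: truncation discards the high-frequency components, so the original signal cannot be recovered exactly.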