Fourier Mixing¶
spectrans.layers.mixing.fourier ¶
Fourier-based mixing layers for spectral transformers.
Implements Fourier-based token mixing mechanisms, including the FNet architecture, which replaces attention with two-dimensional Fourier transforms. The implementations follow the mathematical formulations of the original papers and are written in PyTorch.
Mixing is performed in the frequency domain using Fast Fourier Transforms, which gives \(O(n \log n)\) complexity in sequence length compared with \(O(n^2)\) for attention, while remaining competitive on sequence modeling tasks.
Classes:
| Name | Description |
|---|---|
| `FourierMixing` | Basic FNet-style Fourier mixing with 2D FFT operations. |
| `FourierMixing1D` | 1D Fourier mixing along sequence dimension only. |
| `RealFourierMixing` | Memory-efficient variant using real FFT for real-valued inputs. |
Examples:
Basic FNet-style mixing:
>>> import torch
>>> from spectrans.layers.mixing.fourier import FourierMixing
>>> mixer = FourierMixing(hidden_dim=768)
>>> input_seq = torch.randn(32, 512, 768) # (batch, seq_len, hidden)
>>> output = mixer(input_seq)
>>> assert output.shape == input_seq.shape
Memory-efficient real variant:
>>> from spectrans.layers.mixing.fourier import RealFourierMixing
>>> real_mixer = RealFourierMixing(hidden_dim=768, use_real_fft=True)
>>> output_real = real_mixer(input_seq)
1D sequence mixing only:
>>> from spectrans.layers.mixing.fourier import FourierMixing1D
>>> seq_mixer = FourierMixing1D(hidden_dim=768)
>>> output_1d = seq_mixer(input_seq)
Notes
Mathematical Foundation:
The FNet mixing operation is defined as: $$ \text{FourierMix}(\mathbf{X}) = \text{Re}(\mathcal{F}_d(\mathcal{F}_n(\mathbf{X}))) $$
where \(\mathcal{F}_n\) is the 1D DFT along the sequence dimension, \(\mathcal{F}_d\) is the 1D DFT along the feature dimension, and \(\text{Re}(\cdot)\) denotes real part extraction.
This is implemented using PyTorch's 2D FFT as: $$ \text{FourierMix}(\mathbf{X}) = \text{Re}(\text{fft2d}(\mathbf{X}, \text{dim}=(-2, -1))) $$
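As a minimal sketch of why a single 2D FFT covers both 1D transforms, the snippet below writes the sequence and feature DFTs as explicit matrix products and compares the result with one 2D FFT call. It uses NumPy rather than PyTorch to stay dependency-light (`torch.fft.fft2` plays the same role); the helper `dft_matrix` and the tensor sizes are illustrative assumptions, not part of spectrans.

```python
import numpy as np

def dft_matrix(size: int) -> np.ndarray:
    """Unnormalized DFT matrix W with W[j, k] = exp(-2*pi*i*j*k / size)."""
    idx = np.arange(size)
    return np.exp(-2j * np.pi * np.outer(idx, idx) / size)

rng = np.random.default_rng(0)
n, d = 8, 4                           # sequence length, hidden size
x = rng.standard_normal((n, d))

# The DFT along the sequence axis is a left multiplication; the DFT along
# the feature axis is a right multiplication (the DFT matrix is symmetric).
via_matrices = (dft_matrix(n) @ x @ dft_matrix(d)).real

# One 2D FFT over the last two axes, followed by real part extraction.
via_fft2 = np.fft.fft2(x, axes=(-2, -1)).real

matches = np.allclose(via_matrices, via_fft2)
```

Because both transforms are fixed matrices, the mixing is a linear map with no learnable weights.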
Time complexity is \(O(nd \log n + nd \log d) = O(nd \log(nd))\), with \(O(nd)\) space for storing frequency domain representations. The real FFT variant exploits Hermitian symmetry for approximately 2x memory and computational savings when inputs are real-valued.
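The Hermitian symmetry behind the real FFT savings can be checked directly. This is a small NumPy sketch (the same function names exist in `torch.fft`); the signal length is an arbitrary assumption.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 16
x = rng.standard_normal(n)            # real-valued signal

full = np.fft.fft(x)                  # n complex bins
half = np.fft.rfft(x)                 # n // 2 + 1 complex bins

# Hermitian symmetry for real input: X[n - k] == conj(X[k]) for k = 1..n-1
symmetric = np.allclose(full[1:][::-1], np.conj(full[1:]))

# The real FFT keeps only the non-redundant half of the spectrum ...
same_bins = np.allclose(half, full[: n // 2 + 1])

# ... and that half is enough to recover the original signal.
recovered = np.fft.irfft(half, n=n)
lossless = np.allclose(recovered, x)
```

Storing roughly half the bins is where the approximate 2x figure comes from.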
Complexity in sequence length is log-linear, \(O(n \log n)\), in contrast to the quadratic complexity of attention. The layer has no learnable parameters, which reduces overfitting risk, and the transform parallelizes well. By the DFT shift theorem, a circular shift along the sequence or feature dimension changes only the phase of the spectrum, not its magnitude. The limitations are that mixing is purely positional, with no content-based interactions; tasks requiring precise positional reasoning may be challenging; and real part extraction can lose information.
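To make the information loss from real part extraction concrete, the sketch below builds two different real sequences whose DFTs share the same real part. For real input, circularly reversing the sequence conjugates the spectrum, so only the imaginary part changes. The values are arbitrary and NumPy stands in for `torch.fft`.

```python
import numpy as np

x = np.array([1.0, 2.0, 3.0, 4.0])

# Circular reversal x'[j] = x[(-j) mod N]; here [1, 4, 3, 2].
x_reversed = np.roll(x[::-1], 1)

spec_a = np.fft.fft(x)
spec_b = np.fft.fft(x_reversed)

# The spectra are complex conjugates of each other, so the real parts
# coincide even though the inputs differ.
same_real_part = np.allclose(spec_a.real, spec_b.real)
different_inputs = not np.array_equal(x, x_reversed)
conjugate_spectra = np.allclose(spec_b, np.conj(spec_a))
```

A layer that keeps only the real part therefore cannot distinguish these two inputs on its own.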
References
James Lee-Thorp, Joshua Ainslie, Ilya Eckstein, and Santiago Ontanon. 2022. FNet: Mixing tokens with Fourier transforms. In Proceedings of the 2022 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies (NAACL-HLT), pages 4296-4313, Seattle.
See Also
spectrans.layers.mixing.base : Base classes for mixing operations
spectrans.transforms.fourier : Underlying FFT transform implementations
Classes¶
FourierMixing ¶
FourierMixing(hidden_dim: int, dropout: float = 0.0, norm_eps: float = 1e-05, energy_tolerance: float = 0.0001, fft_norm: FFTNorm = 'ortho', keep_complex: bool = False)
Bases: UnitaryMixingLayer
FNet-style Fourier mixing layer.
Implements the core FNet mixing operation using 2D Fourier transforms along both sequence and feature dimensions, providing an alternative to attention with \(O(n \log n)\) complexity.
The operation performs:
1. 2D FFT across sequence and feature dimensions
2. Optional real part extraction for final output (original FNet behavior) or keep complex values for full information preservation
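The two steps can be sketched as a standalone function. This is an illustration under assumptions, not the library's `forward`: `fourier_mix` is a hypothetical helper, NumPy stands in for `torch.fft`, and dropout is omitted.

```python
import numpy as np

def fourier_mix(x: np.ndarray, keep_complex: bool = False,
                norm: str = "ortho") -> np.ndarray:
    """Hypothetical sketch: 2D FFT over (sequence, hidden), then optional Re()."""
    spectrum = np.fft.fft2(x, axes=(-2, -1), norm=norm)   # step 1
    return spectrum if keep_complex else spectrum.real    # step 2

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 8, 4))    # (batch, seq_len, hidden)

real_out = fourier_mix(x)                         # real, FNet-style
complex_out = fourier_mix(x, keep_complex=True)   # full spectrum
```

Both outputs have the input's shape; only the dtype differs.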
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `hidden_dim` | `int` | Hidden dimension of the input tensors. | required |
| `dropout` | `float` | Dropout probability applied after the mixing operation. | `0.0` |
| `norm_eps` | `float` | Epsilon for numerical stability. | `1e-5` |
| `energy_tolerance` | `float` | Tolerance for energy preservation verification. | `1e-4` |
| `fft_norm` | `str` | Normalization mode for FFT operations ("forward", "backward", "ortho"). | `"ortho"` |
| `keep_complex` | `bool` | If True, keeps complex values from FFT. If False (default), takes only the real part as in original FNet. | `False` |
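The effect of the three normalization modes on signal energy can be sketched with NumPy, whose `norm` argument accepts the same three strings as `torch.fft`. How spectrans applies `energy_tolerance` internally is not shown in this page, so the relative-error check below is an assumption for illustration only.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((8, 4))       # (seq_len, hidden), 32 elements
energy_in = float(np.sum(np.abs(x) ** 2))

# Energy of the full complex spectrum under each normalization mode.
energies = {
    norm: float(np.sum(np.abs(np.fft.fft2(x, norm=norm)) ** 2))
    for norm in ("backward", "ortho", "forward")
}

# "ortho" scales by 1/sqrt(N), making the transform unitary (Parseval).
energy_tolerance = 1e-4               # mirrors the documented default
ortho_preserves = abs(energies["ortho"] - energy_in) <= energy_tolerance * energy_in
```

With `"backward"` the spectrum energy is multiplied by the number of elements, and with `"forward"` it is divided by it. The equality holds for the complex spectrum; taking the real part afterwards generally reduces the energy.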
Attributes:
| Name | Type | Description |
|---|---|---|
| `fft2d` | `FFT2D` | 2D Fourier transform module. |
| `keep_complex` | `bool` | Whether to keep complex values or extract real part. |
Methods:
| Name | Description |
|---|---|
| `forward` | Apply Fourier mixing to input tensor. |
| `get_spectral_properties` | Get spectral properties of Fourier mixing. |
| `from_config` | Create FourierMixing layer from configuration. |
Source code in spectrans/layers/mixing/fourier.py
Functions¶
forward ¶
Apply Fourier mixing to input tensor.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Tensor` | Input tensor of shape (batch_size, sequence_length, hidden_dim). | required |
Returns:
| Type | Description |
|---|---|
| `Tensor` | Mixed tensor of same shape. Complex if keep_complex=True, real values only if keep_complex=False (default). |
get_spectral_properties ¶
Get spectral properties of Fourier mixing.
Returns:
| Type | Description |
|---|---|
| `dict[str, str \| bool]` | Properties including energy preservation and domain information. |
from_config classmethod ¶
from_config(config: FourierMixingConfig) -> FourierMixing
Create FourierMixing layer from configuration.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `config` | `FourierMixingConfig` | Configuration object with layer parameters. | required |
Returns:
| Type | Description |
|---|---|
| `FourierMixing` | Configured Fourier mixing layer. |
FourierMixing1D ¶
FourierMixing1D(hidden_dim: int, dropout: float = 0.0, norm_eps: float = 1e-05, energy_tolerance: float = 0.0001, fft_norm: FFTNorm = 'ortho', keep_complex: bool = False)
Bases: UnitaryMixingLayer
1D Fourier mixing along sequence dimension only.
Applies Fourier transform only along the sequence dimension, preserving feature dimension locality while mixing tokens.
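A minimal sketch of sequence-only mixing, using NumPy in place of the library's `FFT1D` module (whose internals are not shown here). It checks that feature channels stay independent: perturbing one channel leaves the mixed values of the others unchanged.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 8, 4))    # (batch, seq_len, hidden)

# FFT along the sequence axis only; each feature channel is mixed on its own.
mixed = np.fft.fft(x, axis=1, norm="ortho").real

# Perturb feature channel 0 at every position and mix again.
x_mod = x.copy()
x_mod[..., 0] += 1.0
mixed_mod = np.fft.fft(x_mod, axis=1, norm="ortho").real

other_channels_unchanged = np.allclose(mixed[..., 1:], mixed_mod[..., 1:])
channel_zero_changed = not np.allclose(mixed[..., 0], mixed_mod[..., 0])
```

A 2D FFT, by contrast, would spread the perturbation across every feature channel.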
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `hidden_dim` | `int` | Hidden dimension of the input tensors. | required |
| `dropout` | `float` | Dropout probability applied after the mixing operation. | `0.0` |
| `norm_eps` | `float` | Epsilon for numerical stability. | `1e-5` |
| `energy_tolerance` | `float` | Tolerance for energy preservation verification. | `1e-4` |
| `fft_norm` | `str` | Normalization mode for FFT operations. | `"ortho"` |
| `keep_complex` | `bool` | If True, keeps complex values from FFT. If False (default), takes only the real part. | `False` |
Attributes:
| Name | Type | Description |
|---|---|---|
| `fft1d` | `FFT1D` | 1D Fourier transform module. |
| `keep_complex` | `bool` | Whether to keep complex values or extract real part. |
Methods:
| Name | Description |
|---|---|
| `forward` | Apply 1D Fourier mixing to input tensor. |
| `get_spectral_properties` | Get spectral properties of 1D Fourier mixing. |
Functions¶
forward ¶
Apply 1D Fourier mixing to input tensor.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Tensor` | Input tensor of shape (batch_size, sequence_length, hidden_dim). | required |
Returns:
| Type | Description |
|---|---|
| `Tensor` | Mixed tensor with Fourier transform applied along sequence dimension. Complex if keep_complex=True, real values only if keep_complex=False. |
get_spectral_properties ¶
Get spectral properties of 1D Fourier mixing.
Returns:
| Type | Description |
|---|---|
| `dict[str, str \| bool]` | Properties specific to 1D sequence mixing. |
RealFourierMixing ¶
RealFourierMixing(hidden_dim: int, use_real_fft: bool = True, dropout: float = 0.0, norm_eps: float = 1e-05, energy_tolerance: float = 0.0001, fft_norm: FFTNorm = 'ortho')
Bases: UnitaryMixingLayer
Memory-efficient real Fourier mixing.
Uses real FFT operations to exploit Hermitian symmetry, providing ~2x memory and computational savings for real inputs.
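The storage saving can be seen from the shape of a real 2D FFT. The sketch below uses NumPy's `rfft2` and `irfft2` (the `torch.fft` equivalents share these names); how `RealFourierMixing` turns the half spectrum back into a real output is not documented on this page, so the inverse transform here is only a round-trip check, not the layer's behavior.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 8, 6                           # sequence length, hidden size
x = rng.standard_normal((n, d))

full = np.fft.fft2(x, norm="ortho")   # shape (n, d)
half = np.fft.rfft2(x, norm="ortho")  # shape (n, d // 2 + 1)

# The half spectrum equals the leading columns of the full spectrum ...
same_bins = np.allclose(half, full[:, : d // 2 + 1])

# ... and still determines the real input completely.
restored = np.fft.irfft2(half, s=x.shape, norm="ortho")
lossless = np.allclose(restored, x)
```

Only the last transformed axis is halved, which is why the saving approaches 2x rather than reaching it exactly.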
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `hidden_dim` | `int` | Hidden dimension of the input tensors. | required |
| `use_real_fft` | `bool` | Whether to use real FFT for efficiency. | `True` |
| `dropout` | `float` | Dropout probability applied after mixing. | `0.0` |
| `norm_eps` | `float` | Epsilon for numerical stability. | `1e-5` |
| `energy_tolerance` | `float` | Tolerance for energy preservation verification. | `1e-4` |
| `fft_norm` | `str` | Normalization mode for FFT operations. | `"ortho"` |
Attributes:
| Name | Type | Description |
|---|---|---|
| `use_real_fft` | `bool` | Whether real FFT is enabled. |
| `rfft` | `RFFT` | Real FFT transform for sequence dimension. |
| `rfft2d` | `RFFT2D` | Real 2D FFT transform for both dimensions. |
Methods:
| Name | Description |
|---|---|
| `forward` | Apply real Fourier mixing to input tensor. |
| `get_spectral_properties` | Get spectral properties of real Fourier mixing. |
Functions¶
forward ¶
Apply real Fourier mixing to input tensor.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Tensor` | Input tensor of shape (batch_size, sequence_length, hidden_dim). Should be real-valued for optimal efficiency. | required |
Returns:
| Type | Description |
|---|---|
| `Tensor` | Mixed tensor, guaranteed to be real-valued. |
get_spectral_properties ¶
Get spectral properties of real Fourier mixing.
Returns:
| Type | Description |
|---|---|
| `dict[str, str \| bool]` | Properties including efficiency characteristics. |
SeparableFourierMixing ¶
SeparableFourierMixing(hidden_dim: int, mix_features: bool = True, mix_sequence: bool = True, dropout: float = 0.0, norm_eps: float = 1e-05, energy_tolerance: float = 0.0001, fft_norm: FFTNorm = 'ortho')
Bases: UnitaryMixingLayer
Separable Fourier mixing with sequence and feature transforms.
Applies separate 1D Fourier transforms along sequence and feature dimensions, which can be more efficient than 2D FFT for certain tensor shapes and provides different mixing characteristics.
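A rough sketch of the separable scheme, assuming the two flags simply gate one 1D FFT each and that the real part is taken at the end (the page does not state how the layer handles complex values). `separable_mix` is a hypothetical helper written with NumPy, not the library's implementation.

```python
import numpy as np

def separable_mix(x: np.ndarray, mix_sequence: bool = True,
                  mix_features: bool = True, norm: str = "ortho") -> np.ndarray:
    """Hypothetical sketch: optional 1D FFTs along sequence and feature axes."""
    out = x.astype(complex)
    if mix_sequence:
        out = np.fft.fft(out, axis=-2, norm=norm)   # mix tokens
    if mix_features:
        out = np.fft.fft(out, axis=-1, norm=norm)   # mix features
    return out.real                                 # assumed real output

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 8, 4))    # (batch, seq_len, hidden)

both = separable_mix(x)
sequence_only = separable_mix(x, mix_features=False)
neither = separable_mix(x, mix_sequence=False, mix_features=False)

# With both flags on, the result coincides with a single 2D FFT.
matches_fft2 = np.allclose(both, np.fft.fft2(x, norm="ortho").real)
```

With `mix_features=False` the sketch reduces to sequence-only mixing, and with both flags off it returns the input unchanged.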
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `hidden_dim` | `int` | Hidden dimension of the input tensors. | required |
| `mix_features` | `bool` | Whether to apply FFT along feature dimension. | `True` |
| `mix_sequence` | `bool` | Whether to apply FFT along sequence dimension. | `True` |
| `dropout` | `float` | Dropout probability. | `0.0` |
| `norm_eps` | `float` | Epsilon for numerical stability. | `1e-5` |
| `energy_tolerance` | `float` | Tolerance for energy preservation verification. | `1e-4` |
| `fft_norm` | `str` | FFT normalization mode. | `"ortho"` |
Attributes:
| Name | Type | Description |
|---|---|---|
| `mix_features` | `bool` | Whether feature mixing is enabled. |
| `mix_sequence` | `bool` | Whether sequence mixing is enabled. |
| `fft1d` | `FFT1D` | 1D FFT transform module. |
Methods:
| Name | Description |
|---|---|
| `forward` | Apply separable Fourier mixing. |
| `get_spectral_properties` | Get properties of separable mixing. |
Functions¶
forward ¶
Apply separable Fourier mixing.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Tensor` | Input tensor of shape (batch_size, sequence_length, hidden_dim). | required |
Returns:
| Type | Description |
|---|---|
| `Tensor` | Mixed tensor after applying selected transforms. |
get_spectral_properties ¶
Get properties of separable mixing.
Returns:
| Type | Description |
|---|---|
| `dict[str, str \| bool]` | Properties reflecting the separable nature. |