Global Filter Mixing¶
spectrans.layers.mixing.global_filter ¶
Global Filter Networks (GFNet) mixing layers.
Implements Global Filter Network mixing layers that apply learnable complex-valued filters in the frequency domain. GFNet provides an alternative to attention by performing element-wise filtering operations in the Fourier domain, maintaining \(O(n \log n)\) complexity while introducing learnable parameters.
Learnable complex filters can selectively emphasize or suppress individual frequency components. This gives more modeling flexibility than parameter-free Fourier mixing at the same \(O(n \log n)\) asymptotic cost.
Classes:
| Name | Description |
|---|---|
| `GlobalFilterMixing` | Basic GFNet global filter with learnable complex filters. |
| `GlobalFilterMixing2D` | 2D variant applying filters to both sequence and feature dimensions. |
| `AdaptiveGlobalFilter` | Advanced variant with adaptive filter initialization and regularization. |
Examples:
Basic global filter usage:
>>> import torch
>>> from spectrans.layers.mixing.global_filter import GlobalFilterMixing
>>> filter_layer = GlobalFilterMixing(hidden_dim=768, sequence_length=512)
>>> input_seq = torch.randn(32, 512, 768)
>>> output = filter_layer(input_seq)
>>> assert output.shape == input_seq.shape
2D global filtering:
>>> from spectrans.layers.mixing.global_filter import GlobalFilterMixing2D
>>> filter_2d = GlobalFilterMixing2D(hidden_dim=768, sequence_length=512)
>>> output_2d = filter_2d(input_seq)
Adaptive filtering with regularization:
>>> from spectrans.layers.mixing.global_filter import AdaptiveGlobalFilter
>>> adaptive_filter = AdaptiveGlobalFilter(
... hidden_dim=768, sequence_length=512,
... filter_regularization=0.01, adaptive_initialization=True
... )
>>> output_adaptive = adaptive_filter(input_seq)
Notes
Mathematical Foundation:
The Global Filter operation is defined as: $$ \text{GF}(\mathbf{X}) = \mathcal{F}^{-1}(\mathbf{H} \odot \mathcal{F}(\mathbf{X})) $$
where \(\mathcal{F}\) is the FFT along the sequence dimension, \(\mathcal{F}^{-1}\) is the inverse FFT, \(\mathbf{H} \in \mathbb{C}^{n \times d}\) is a learnable complex filter, and \(\odot\) denotes element-wise (Hadamard) multiplication.
The complex filter is parameterized as: $$ \mathbf{H} = \sigma(\mathbf{W}_r + i\mathbf{W}_i) $$
where \(\mathbf{W}_r, \mathbf{W}_i \in \mathbb{R}^{n \times d}\) are real-valued learnable parameters, \(\sigma\) is an activation function (typically sigmoid or tanh), and \(i\) is the imaginary unit.
Sigmoid activation provides soft gating with values in \((0,1)\). Tanh provides symmetric activation with values in \((-1,1)\). Identity activation applies no transformation, so filter values are unbounded and training may be unstable.
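The two formulas above can be sketched in plain PyTorch. This is a minimal illustration, not the library's implementation: the helper names `make_filter` and `apply_filter` are hypothetical, and two details are assumptions rather than facts from this page, namely that the sigmoid is applied to the real and imaginary parts separately and that only the real part of the inverse FFT is kept.

```python
import torch


def make_filter(w_real: torch.Tensor, w_imag: torch.Tensor) -> torch.Tensor:
    """Build a complex filter H of shape (n, d) from two real parameter tensors.

    Assumption: the activation acts on the real and imaginary parts separately.
    """
    return torch.complex(torch.sigmoid(w_real), torch.sigmoid(w_imag))


def apply_filter(x: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
    """Compute F^{-1}(H * F(X)) along the sequence axis (dim=1)."""
    x_freq = torch.fft.fft(x, dim=1, norm="ortho")   # (batch, n, d), complex
    y_freq = x_freq * h                              # H broadcasts over the batch
    # Assumption: the real part is returned; a generic complex H yields complex output.
    return torch.fft.ifft(y_freq, dim=1, norm="ortho").real


torch.manual_seed(0)
batch, n, d = 2, 16, 8
x = torch.randn(batch, n, d)
w_real = torch.randn(n, d) * 0.02
w_imag = torch.randn(n, d) * 0.02

y = apply_filter(x, make_filter(w_real, w_imag))

# An all-ones filter leaves every frequency untouched, so the input comes back.
identity = torch.ones(n, d, dtype=torch.complex64)
x_roundtrip = apply_filter(x, identity)
```

The all-ones filter is a convenient sanity check: if the round trip does not reproduce the input, the FFT axis or normalization mode is mismatched between the forward and inverse transforms.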
Time complexity is \(O(nd \log n)\) for FFT operations. Space complexity is \(O(nd)\) for filter parameters and frequency representations. The model uses \(2nd\) real parameters (\(\mathbf{W}_r\) and \(\mathbf{W}_i\)).
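To make these counts concrete, the arithmetic can be worked through for the sizes used in the examples on this page (\(n = 512\), \(d = 768\)). The function name `gfnet_costs` is hypothetical, constant factors are ignored, and the \(n^2 d\) attention figure is the standard order-of-magnitude estimate included only for comparison.

```python
import math


def gfnet_costs(n: int, d: int) -> tuple[int, float, int]:
    """Return (filter parameters, FFT operation order, attention operation order)."""
    params = 2 * n * d                 # W_r and W_i hold n * d real values each
    fft_ops = n * d * math.log2(n)     # O(n d log n), constants ignored
    attention_ops = n * n * d          # O(n^2 d), for comparison only
    return params, fft_ops, attention_ops


params, fft_ops, attention_ops = gfnet_costs(n=512, d=768)
print(params)                   # 786432
print(int(fft_ops))             # 3538944
print(attention_ops)            # 201326592
```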
Learnable parameters allow task-specific adaptation compared to FNet. Filters can emphasize important frequencies and suppress noise, adding expressiveness while keeping the \(O(n \log n)\) cost in sequence length. Filter initialization affects training stability, regularization discourages overfitting to specific frequencies, and the choice of activation affects both gradient flow and expressiveness.
References
Yongming Rao, Wenliang Zhao, Zheng Zhu, Jiwen Lu, and Jie Zhou. 2021. Global filter networks for image classification. In Advances in Neural Information Processing Systems 34 (NeurIPS 2021), pages 980-993.
See Also
spectrans.layers.mixing.base : Base classes for mixing layers
spectrans.layers.mixing.fourier : Parameter-free Fourier mixing layers
spectrans.transforms.fourier : Underlying FFT implementations
Classes¶
GlobalFilterMixing ¶
GlobalFilterMixing(hidden_dim: int, sequence_length: int, activation: ActivationType = 'sigmoid', dropout: float = 0.0, norm_eps: float = 1e-05, learnable_filters: bool = True, fft_norm: FFTNorm = 'ortho', filter_init_std: float = 0.02)
Bases: FilterMixingLayer
Global Filter Network mixing layer.
Implements the core GFNet mixing operation with learnable complex filters applied in the frequency domain along the sequence dimension.
The layer uses interpolation to adapt filters to different sequence lengths, processing variable-length inputs while preserving learned frequency patterns. This provides resolution independence compared to fixed-size filtering.
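One way such filter resizing could work is sketched below. This page does not say which interpolation scheme the layer uses, so the linear mode, the `align_corners=True` setting, and the helper name `resize_filter` are all assumptions made for illustration. The real and imaginary parameter tensors would each be resized this way.

```python
import torch
import torch.nn.functional as F


def resize_filter(w: torch.Tensor, target_len: int) -> torch.Tensor:
    """Resample a real (n, d) filter tensor to (target_len, d) along the frequency axis."""
    if w.shape[0] == target_len:
        return w
    # F.interpolate expects (batch, channels, length), so move d into the channel slot.
    w_3d = w.t().unsqueeze(0).contiguous()            # (1, d, n)
    resized = F.interpolate(
        w_3d, size=target_len, mode="linear", align_corners=True
    )                                                 # (1, d, target_len)
    return resized.squeeze(0).t()                     # (target_len, d)


torch.manual_seed(0)
w = torch.randn(16, 8)            # filter learned at base length 16
w_long = resize_filter(w, 24)     # adapted to a length-24 input
```

With `align_corners=True` the first and last frequency entries are preserved exactly, and the entries in between are linear blends of their neighbors.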
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `hidden_dim` | `int` | Hidden dimension of input tensors. | required |
| `sequence_length` | `int` | Base sequence length for filter parameter initialization. The filters will be interpolated to match actual input sequence lengths. | required |
| `activation` | `ActivationType` | Activation function applied to filter parameters ("sigmoid", "tanh", "identity"). | `"sigmoid"` |
| `dropout` | `float` | Dropout probability applied after filtering. | `0.0` |
| `norm_eps` | `float` | Epsilon for numerical stability. | `1e-5` |
| `learnable_filters` | `bool` | Whether filter parameters are learnable (always True for this class). | `True` |
| `fft_norm` | `str` | FFT normalization mode. | `"ortho"` |
| `filter_init_std` | `float` | Standard deviation for filter parameter initialization. | `0.02` |
Attributes:
| Name | Type | Description |
|---|---|---|
| `activation` | `str` | Activation function name. |
| `filter_real` | `Parameter` | Real part of complex filter parameters. |
| `filter_imag` | `Parameter` | Imaginary part of complex filter parameters. |
| `fft1d` | `FFT1D` | 1D FFT transform for sequence dimension. |
| `activation_fn` | `Module` | Activation function module (Sigmoid, Tanh, or Identity). |
Methods:
| Name | Description |
|---|---|
| `forward` | Apply global filtering to input tensor. |
| `get_filter_response` | Get the current frequency response of the filters. |
| `get_spectral_properties` | Get spectral properties of global filtering. |
| `from_config` | Create GlobalFilterMixing layer from configuration. |
Source code in spectrans/layers/mixing/global_filter.py
Functions¶
forward ¶
Apply global filtering to input tensor.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Tensor` | Input tensor of shape (batch_size, sequence_length, hidden_dim). | required |

Returns:

| Type | Description |
|---|---|
| `Tensor` | Filtered tensor of same shape as input. |
get_filter_response ¶
Get the current frequency response of the filters.
Returns:
| Type | Description |
|---|---|
| `Tensor` | Complex-valued frequency response of shape (sequence_length, hidden_dim). |
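A complex response of this shape is usually inspected through its magnitude (gain) and phase. The sketch below builds a stand-in tensor `h` instead of calling the layer, so it makes no claim about the values the method actually returns. Applying a sigmoid separately to the real and imaginary parts is an assumption used only to produce plausible data.

```python
import math

import torch

torch.manual_seed(0)
n, d = 16, 8

# Stand-in for a (sequence_length, hidden_dim) complex response.
# Assumption: sigmoid-activated real and imaginary parts.
h = torch.complex(torch.sigmoid(torch.randn(n, d)), torch.sigmoid(torch.randn(n, d)))

magnitude = h.abs()      # gain applied to each (frequency, channel) pair
phase = h.angle()        # phase shift in radians

# Average gain per frequency bin across channels, useful for spotting
# which frequencies the filter passes or attenuates.
gain_per_frequency = magnitude.mean(dim=1)
```

Under the stated assumption both parts lie in \((0,1)\), so the magnitude stays below \(\sqrt{2}\) and the phase lies between \(0\) and \(\pi/2\).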
get_spectral_properties ¶
Get spectral properties of global filtering.
Returns:
| Type | Description |
|---|---|
| `dict[str, str \| bool \| int]` | Properties including filter characteristics. |
from_config classmethod ¶
from_config(config: GlobalFilterMixingConfig) -> GlobalFilterMixing
Create GlobalFilterMixing layer from configuration.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `config` | `GlobalFilterMixingConfig` | Configuration object with layer parameters. | required |

Returns:

| Type | Description |
|---|---|
| `GlobalFilterMixing` | Configured global filter mixing layer. |
GlobalFilterMixing2D ¶
GlobalFilterMixing2D(hidden_dim: int, sequence_length: int, activation: ActivationType = 'sigmoid', dropout: float = 0.0, norm_eps: float = 1e-05, learnable_filters: bool = True, fft_norm: FFTNorm = 'ortho', filter_init_std: float = 0.02)
Bases: FilterMixingLayer
2D Global Filter mixing with filtering along both dimensions.
Extends global filtering to both sequence and feature dimensions, similar to FNet's 2D FFT but with learnable filters.
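As a rough illustration of filtering along both axes, the sketch below takes a 2D FFT over the sequence and feature dimensions, multiplies by a complex filter, and inverts. It is not the class's implementation: the helper name `apply_filter_2d`, the filter shape `(n, d)`, and keeping only the real part are assumptions.

```python
import torch


def apply_filter_2d(x: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
    """Filter x of shape (batch, n, d) with a complex (n, d) filter in the 2D frequency domain."""
    x_freq = torch.fft.fft2(x, dim=(1, 2), norm="ortho")   # FFT over sequence and feature axes
    y_freq = x_freq * h                                     # broadcast over the batch axis
    # Assumption: only the real part of the inverse transform is kept.
    return torch.fft.ifft2(y_freq, dim=(1, 2), norm="ortho").real


torch.manual_seed(0)
batch, n, d = 2, 16, 8
x = torch.randn(batch, n, d)

h = torch.complex(torch.rand(n, d), torch.rand(n, d))
y = apply_filter_2d(x, h)

# Sanity check: an all-ones filter reproduces the input.
x_roundtrip = apply_filter_2d(x, torch.ones(n, d, dtype=torch.complex64))
```

Unlike the 1D case, each filter entry here scales a joint (sequence frequency, feature frequency) component, so information is mixed across hidden channels as well as across positions.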
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `hidden_dim` | `int` | Hidden dimension of input tensors. | required |
| `sequence_length` | `int` | Expected sequence length. | required |
| `activation` | `ActivationType` | Activation function for filter parameters. | `"sigmoid"` |
| `dropout` | `float` | Dropout probability. | `0.0` |
| `norm_eps` | `float` | Epsilon for numerical stability. | `1e-5` |
| `learnable_filters` | `bool` | Whether filters are learnable. | `True` |
| `fft_norm` | `str` | FFT normalization mode. | `"ortho"` |
| `filter_init_std` | `float` | Filter parameter initialization standard deviation. | `0.02` |
Attributes:
| Name | Type | Description |
|---|---|---|
| `filter_real` | `Parameter` | Real part of 2D complex filters. |
| `filter_imag` | `Parameter` | Imaginary part of 2D complex filters. |
| `fft2d` | `FFT2D` | 2D FFT transform module. |
| `activation_fn` | `Module` | Activation function. |
Methods:
| Name | Description |
|---|---|
| `forward` | Apply 2D global filtering. |
| `get_filter_response` | Get 2D frequency response. |
| `get_spectral_properties` | Get 2D filter properties. |
Functions¶
forward ¶
Apply 2D global filtering.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Tensor` | Input tensor of shape (batch_size, sequence_length, hidden_dim). | required |

Returns:

| Type | Description |
|---|---|
| `Tensor` | Filtered tensor of same shape. |
get_filter_response ¶
Get 2D frequency response.
Returns:
| Type | Description |
|---|---|
| `Tensor` | Complex 2D frequency response. |
get_spectral_properties ¶
Get 2D filter properties.
Returns:
| Type | Description |
|---|---|
| `dict[str, str \| bool \| int]` | 2D filtering characteristics. |
AdaptiveGlobalFilter ¶
AdaptiveGlobalFilter(hidden_dim: int, sequence_length: int, activation: ActivationType = 'sigmoid', dropout: float = 0.0, norm_eps: float = 1e-05, learnable_filters: bool = True, fft_norm: FFTNorm = 'ortho', filter_init_std: float = 0.02, filter_regularization: float = 0.0, adaptive_initialization: bool = True, spectral_dropout_p: float = 0.0)
Bases: FilterMixingLayer
Adaptive Global Filter with regularization and smart initialization.
Enhanced version of global filtering with adaptive initialization strategies, regularization options, and improved training stability.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `hidden_dim` | `int` | Hidden dimension of input tensors. | required |
| `sequence_length` | `int` | Expected sequence length. | required |
| `activation` | `ActivationType` | Filter activation function. | `"sigmoid"` |
| `dropout` | `float` | Dropout probability. | `0.0` |
| `norm_eps` | `float` | Numerical stability epsilon. | `1e-5` |
| `learnable_filters` | `bool` | Whether filters are learnable. | `True` |
| `fft_norm` | `str` | FFT normalization. | `"ortho"` |
| `filter_init_std` | `float` | Filter initialization standard deviation. | `0.02` |
| `filter_regularization` | `float` | L2 regularization strength for filter parameters. | `0.0` |
| `adaptive_initialization` | `bool` | Whether to use frequency-aware initialization. | `True` |
| `spectral_dropout_p` | `float` | Spectral dropout probability in frequency domain. | `0.0` |
Attributes:
| Name | Type | Description |
|---|---|---|
| `filter_regularization` | `float` | Regularization strength. |
| `adaptive_initialization` | `bool` | Whether adaptive initialization is used. |
| `spectral_dropout_p` | `float` | Spectral dropout probability. |
| `spectral_dropout` | `Module` | Spectral dropout layer. |
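Spectral dropout can be pictured as ordinary dropout applied to frequency coefficients instead of activations. The sketch below is one plausible reading, not the layer's code: the function name `spectral_dropout`, the per-coefficient mask, and the inverted-dropout rescaling by `1 / (1 - p)` are all assumptions.

```python
import torch


def spectral_dropout(x_freq: torch.Tensor, p: float, training: bool = True) -> torch.Tensor:
    """Randomly zero frequency coefficients with probability p and rescale the survivors."""
    if not 0.0 <= p < 1.0:
        raise ValueError("p must be in [0, 1)")
    if not training or p == 0.0:
        return x_freq
    # Real-valued keep mask; multiplying a complex tensor by it zeroes both parts together.
    keep = (torch.rand(x_freq.shape, device=x_freq.device) >= p).to(x_freq.real.dtype)
    return x_freq * keep / (1.0 - p)


torch.manual_seed(0)
x_freq = torch.fft.fft(torch.randn(2, 16, 8), dim=1)
out = spectral_dropout(x_freq, p=0.5)
```

Rescaling the surviving coefficients keeps the expected value of each coefficient unchanged, so no adjustment is needed when the mask is disabled at evaluation time.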
Methods:
| Name | Description |
|---|---|
| `forward` | Apply adaptive global filtering. |
| `get_filter_response` | Get adaptive frequency response. |
| `get_regularization_loss` | Compute L2 regularization loss for filter parameters. |
| `get_spectral_properties` | Get adaptive filter properties. |
Functions¶
forward ¶
Apply adaptive global filtering.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Tensor` | Input tensor of shape (batch_size, sequence_length, hidden_dim). | required |

Returns:

| Type | Description |
|---|---|
| `Tensor` | Adaptively filtered tensor. |
get_filter_response ¶
Get adaptive frequency response.
Returns:
| Type | Description |
|---|---|
| `Tensor` | Complex frequency response with current parameters. |
get_regularization_loss ¶
Compute L2 regularization loss for filter parameters.
Returns:
| Type | Description |
|---|---|
| `Tensor` | Scalar regularization loss. |
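An L2 penalty on the filter parameters typically takes the form sketched below. The exact expression the method computes (sum versus mean, which tensors are included) is not stated on this page, so the formula and the helper name `l2_filter_penalty` are assumptions for illustration.

```python
import torch


def l2_filter_penalty(
    w_real: torch.Tensor, w_imag: torch.Tensor, strength: float
) -> torch.Tensor:
    """Return strength * (||W_r||^2 + ||W_i||^2) as a scalar tensor."""
    return strength * (w_real.pow(2).sum() + w_imag.pow(2).sum())


w_real = torch.ones(4, 3, requires_grad=True)
w_imag = torch.ones(4, 3, requires_grad=True)

penalty = l2_filter_penalty(w_real, w_imag, strength=0.01)   # 0.01 * (12 + 12)

# The penalty is differentiable, so it can be added to the task loss before backward().
penalty.backward()
```

In a training loop the scalar would usually be added to the task loss, for example `loss = task_loss + penalty`, so that the optimizer shrinks filter values that do not help the objective.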
get_spectral_properties ¶
Get adaptive filter properties.
Returns:
| Type | Description |
|---|---|
| `dict[str, str \| bool \| int]` | Comprehensive properties including adaptive features. |