Mixing Layers
spectrans.layers.mixing
Spectral mixing layer implementations for token mixing.
Provides spectral mixing layers as alternatives to attention mechanisms. These layers operate in spectral domains, using transforms such as the FFT or the discrete wavelet transform, and keep the cost of token mixing linear or log-linear in sequence length rather than quadratic.
The implementations cover parameter-free Fourier mixing (FNet style), learnable complex filters in the frequency domain (GFNet style), mode-truncated Fourier operators (AFNO style), wavelet-based mixing, and variants with adaptive initialization and multi-dimensional transforms.
Modules:
| Name | Description |
|---|---|
| afno | Adaptive Fourier Neural Operator mixing implementations. |
| base | Base classes and interfaces for mixing layers. |
| fourier | Fourier transform-based mixing layers. |
| global_filter | Global filter networks with learnable parameters. |
| wavelet | Wavelet transform-based mixing layers. |
Classes:
| Name | Description |
|---|---|
| AFNOMixing | Adaptive Fourier Neural Operator with mode truncation. |
| AdaptiveGlobalFilter | Enhanced global filter with adaptive initialization. |
| FilterMixingLayer | Base class for learnable frequency domain filters. |
| FourierMixing | 2D FFT mixing for both sequence and feature dimensions. |
| FourierMixing1D | 1D FFT mixing along sequence dimension only. |
| GlobalFilterMixing | Learnable complex filters in frequency domain. |
| GlobalFilterMixing2D | 2D variant with filtering in both dimensions. |
| MixingLayer | Base class for spectral mixing operations. |
| RealFourierMixing | Memory-efficient real FFT variant. |
| SeparableFourierMixing | Configurable sequence and/or feature mixing. |
| UnitaryMixingLayer | Base class for energy-preserving mixing transforms. |
| WaveletMixing | 1D wavelet mixing using discrete wavelet transform. |
| WaveletMixing2D | 2D wavelet mixing for spatial data processing. |
Examples:
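Basic Fourier mixing:
>>> import torch
>>> from spectrans.layers.mixing import FourierMixing
>>> input_tensor = torch.randn(2, 512, 768)  # (batch_size, sequence_length, hidden_dim)
>>> mixer = FourierMixing(hidden_dim=768)
>>> output = mixer(input_tensor)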
Global filter with learnable parameters:
>>> from spectrans.layers.mixing import GlobalFilterMixing
>>> filter_mixer = GlobalFilterMixing(hidden_dim=768, sequence_length=512)
>>> filtered_output = filter_mixer(input_tensor)
Adaptive filtering:
>>> from spectrans.layers.mixing import AdaptiveGlobalFilter
>>> adaptive_mixer = AdaptiveGlobalFilter(
... hidden_dim=768, sequence_length=512,
... adaptive_initialization=True, filter_regularization=0.01
... )
>>> adaptive_output = adaptive_mixer(input_tensor)
Notes
Complexity Comparison:
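Standard self-attention costs \(O(n^2 d)\) for sequence length \(n\) and hidden dimension \(d\). FFT-based mixing along the sequence reduces this to \(O(nd \log n)\); 2D variants that also transform the hidden dimension add an \(O(nd \log d)\) term. Global filtering has the same \(O(nd \log n)\) cost plus an \(O(nd)\) element-wise filter multiplication, and introduces learnable parameters.
All mixing layers support batched inputs, are differentiable for end-to-end training, and preserve shape (the output shape equals the input shape). The base classes also provide helpers for verifying mathematical properties such as energy preservation and orthogonality.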
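The complexity gap can be seen in a small NumPy sketch (an illustration, not library code): mixing tokens with an explicit orthonormal DFT matrix is a dense \(n \times n\) product costing \(O(n^2 d)\), while an FFT along the sequence axis computes the same linear map in \(O(nd \log n)\). The helper names below are hypothetical.

```python
import numpy as np

def dft_matrix(n):
    """Orthonormal DFT matrix: F[k, j] = exp(-2j * pi * k * j / n) / sqrt(n)."""
    idx = np.arange(n)
    return np.exp(-2j * np.pi * np.outer(idx, idx) / n) / np.sqrt(n)

def dense_sequence_mix(x):
    """Token mixing as a dense (n x n) @ (n x d) product: O(n^2 d)."""
    return dft_matrix(x.shape[0]) @ x

def fft_sequence_mix(x):
    """The same linear map computed with an FFT along the sequence axis: O(n d log n)."""
    return np.fft.fft(x, axis=0, norm="ortho")

x = np.random.default_rng(0).standard_normal((64, 16))  # (sequence_length, hidden_dim)
print(np.allclose(dense_sequence_mix(x), fft_sequence_mix(x)))  # True
```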
References
James Lee-Thorp, Joshua Ainslie, Ilya Eckstein, and Santiago Ontanon. 2022. FNet: Mixing tokens with Fourier transforms. In Proceedings of the 2022 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies (NAACL-HLT), pages 4296-4313, Seattle.
Yongming Rao, Wenliang Zhao, Zheng Zhu, Jiwen Lu, and Jie Zhou. 2021. Global filter networks for image classification. In Advances in Neural Information Processing Systems 34 (NeurIPS 2021), pages 980-993.
John Guibas, Morteza Mardani, Zongyi Li, Andrew Tao, Anima Anandkumar, and Bryan Catanzaro. 2022. Adaptive Fourier neural operators: Efficient token mixers for transformers. In Proceedings of the International Conference on Learning Representations (ICLR).
See Also
spectrans.layers.mixing.base : Base classes and interfaces.
spectrans.transforms : Underlying spectral transform implementations.
spectrans.blocks : Transformer blocks that use these mixing layers.
Classes
AFNOMixing
AFNOMixing(hidden_dim: int, max_sequence_length: int, modes_seq: int | None = None, modes_hidden: int | None = None, mlp_ratio: float = 2.0, activation: ActivationType = 'gelu', dropout: float = 0.0)
Bases: MixingLayer
Adaptive Fourier Neural Operator mixing layer.
This layer performs efficient token mixing by applying learnable transformations in the truncated Fourier domain, significantly reducing computational cost while maintaining model expressiveness.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| hidden_dim | int | Hidden dimension of the input/output tensors. | required |
| max_sequence_length | int | Maximum sequence length the model will process. | required |
| modes_seq | int \| None | Number of Fourier modes to keep in the sequence dimension. If None, defaults to max_sequence_length // 2. | None |
| modes_hidden | int \| None | Number of Fourier modes to keep in the hidden dimension. If None, defaults to hidden_dim // 2. | None |
| mlp_ratio | float | Expansion ratio for the MLP in the Fourier domain. | 2.0 |
| activation | str | Activation function for the MLP. | 'gelu' |
| dropout | float | Dropout probability for the MLP. | 0.0 |
Attributes:
| Name | Type | Description |
|---|---|---|
| hidden_dim | int | Hidden dimension size. |
| max_sequence_length | int | Maximum supported sequence length. |
| modes_seq | int | Number of retained Fourier modes in the sequence dimension. |
| modes_hidden | int | Number of retained Fourier modes in the hidden dimension. |
| mlp_ratio | float | MLP expansion ratio. |
| fourier_weight | Parameter | Complex-valued learnable weights for Fourier modes. |
| mlp | Sequential | MLP applied in the Fourier domain. |
Examples:
>>> import torch
>>> from spectrans.layers.mixing import AFNOMixing
>>> layer = AFNOMixing(hidden_dim=768, max_sequence_length=512, modes_seq=128)
>>> x = torch.randn(32, 512, 768)
>>> output = layer(x)
>>> print(output.shape)
torch.Size([32, 512, 768])
Methods:
| Name | Description |
|---|---|
| forward | Apply AFNO mixing to input tensor. |
| get_spectral_properties | Get mathematical properties of AFNO operation. |
| from_config | Create AFNOMixing layer from configuration. |
Source code in spectrans/layers/mixing/afno.py
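The sketch below illustrates only the mode-truncation idea behind AFNO, in NumPy: transform along the sequence, keep the lowest modes_seq frequencies, scale them by complex weights, and invert. It is not AFNOMixing's implementation: it omits the MLP in the Fourier domain and the truncation along the hidden dimension, and the function name truncated_fourier_mix is hypothetical.

```python
import numpy as np

def truncated_fourier_mix(x, weights, modes):
    """Mode-truncated Fourier mixing along the sequence axis (simplified AFNO idea).

    x:       real array of shape (batch, seq_len, hidden)
    weights: complex array of shape (modes, hidden), one weight per kept mode and channel
    modes:   number of lowest-frequency rfft bins to keep; higher bins are zeroed
    """
    seq_len = x.shape[1]
    x_hat = np.fft.rfft(x, axis=1, norm="ortho")            # (batch, seq_len // 2 + 1, hidden)
    out_hat = np.zeros_like(x_hat)
    out_hat[:, :modes, :] = x_hat[:, :modes, :] * weights   # scale kept modes, drop the rest
    return np.fft.irfft(out_hat, n=seq_len, axis=1, norm="ortho")

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 32, 8))
w = np.ones((17, 8), dtype=complex)  # keep all 32 // 2 + 1 = 17 bins with unit weights
print(np.allclose(truncated_fourier_mix(x, w, modes=17), x))  # True: identity
```

Keeping fewer modes (for example modes=4 with weights of shape (4, hidden)) acts as a learnable low-pass filter, which is where the cost saving of mode truncation comes from.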
Functions
forward
Apply AFNO mixing to input tensor.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Input tensor of shape (batch_size, sequence_length, hidden_dim). | required |
Returns:
| Type | Description |
|---|---|
| Tensor | Output tensor of same shape as input. |
Source code in spectrans/layers/mixing/afno.py
get_spectral_properties
Get mathematical properties of AFNO operation.
Returns:
| Type | Description |
|---|---|
| dict[str, bool] | Mathematical properties of the transform. |
Source code in spectrans/layers/mixing/afno.py
from_config (classmethod)
from_config(config: AFNOMixingConfig) -> AFNOMixing
Create AFNOMixing layer from configuration.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | AFNOMixingConfig | Configuration object with layer parameters. | required |
Returns:
| Type | Description |
|---|---|
| AFNOMixing | Configured AFNO mixing layer. |
Source code in spectrans/layers/mixing/afno.py
FilterMixingLayer
FilterMixingLayer(hidden_dim: int, sequence_length: int, dropout: float = 0.0, norm_eps: float = 1e-05, learnable_filters: bool = True)
Bases: MixingLayer
Base class for frequency-domain filtering operations.
Filter mixing layers apply learnable filters in the frequency domain, enabling selective emphasis or suppression of frequency components for improved sequence modeling capabilities.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| hidden_dim | int | Hidden dimension of the model. | required |
| sequence_length | int | Expected sequence length for filter initialization. | required |
| dropout | float | Dropout probability for regularization. | 0.0 |
| norm_eps | float | Epsilon for numerical stability. | 1e-5 |
| learnable_filters | bool | Whether filters are learnable parameters. | True |
Attributes:
| Name | Type | Description |
|---|---|---|
| sequence_length | int | Expected sequence length. |
| learnable_filters | bool | Whether filters are learnable. |
Methods:
| Name | Description |
|---|---|
| get_spectral_properties | Get properties specific to filtering operations. |
| get_filter_response | Get the frequency response of the current filters. |
| analyze_frequency_response | Analyze the frequency response characteristics. |
Source code in spectrans/layers/mixing/base.py
Functions
get_spectral_properties
Get properties specific to filtering operations.
Returns:
| Type | Description |
|---|---|
| dict[str, Any] | Dictionary containing filter-specific properties. |
Source code in spectrans/layers/mixing/base.py
get_filter_response (abstractmethod)
Get the frequency response of the current filters.
Returns:
| Type | Description |
|---|---|
| Tensor | Complex-valued frequency response of shape matching the filter parameters. |
Source code in spectrans/layers/mixing/base.py
analyze_frequency_response
Analyze the frequency response characteristics.
Returns:
| Type | Description |
|---|---|
| dict[str, Tensor] | Dictionary containing 'magnitude' (magnitude response), 'phase' (phase response), 'group_delay' (group delay response), and 'passband_energy' (energy in different frequency bands). |
Source code in spectrans/layers/mixing/base.py
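A NumPy sketch of how the magnitude, phase, and group delay entries could be derived from a complex frequency response. It assumes the response is sampled at \(\omega_k = 2\pi k / n\) and defines group delay as \(-d\phi/d\omega\); the library's exact computation (and its 'passband_energy' band split) is not documented here, and analyze_response is a hypothetical name. A pure delay of three samples should report a constant group delay of 3.

```python
import numpy as np

def analyze_response(h):
    """Magnitude, unwrapped phase, and group delay of a 1D complex response h.

    h[k] is assumed to sample the response at omega_k = 2 * pi * k / n.
    """
    n = h.shape[0]
    omega = 2 * np.pi * np.arange(n) / n
    phase = np.unwrap(np.angle(h))  # remove 2*pi jumps so the phase is continuous
    return {
        "magnitude": np.abs(h),
        "phase": phase,
        "group_delay": -np.gradient(phase, omega),  # -d(phase)/d(omega), in samples
    }

n, delay = 64, 3
h = np.exp(-2j * np.pi * np.arange(n) * delay / n)  # a pure 3-sample delay
props = analyze_response(h)
print(np.allclose(props["group_delay"], delay))  # True
print(np.allclose(props["magnitude"], 1.0))      # True
```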
MixingLayer
Bases: SpectralComponent
Base class for spectral mixing operations.
Mixing layers perform token mixing operations using various spectral transforms instead of traditional attention mechanisms. This class provides spectral-specific functionality including mathematical property verification and standardized interfaces for spectral transform operations.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| hidden_dim | int | Hidden dimension of the model. | required |
| dropout | float | Dropout probability for regularization. | 0.0 |
| norm_eps | float | Epsilon for numerical stability in normalization. | 1e-5 |
Attributes:
| Name | Type | Description |
|---|---|---|
| hidden_dim | int | Hidden dimension of the model. |
| dropout | Module | Dropout layer for regularization. |
| norm_eps | float | Epsilon for numerical stability. |
Methods:
| Name | Description |
|---|---|
| get_spectral_properties | Get mathematical properties of the spectral operation. |
| verify_shape_consistency | Verify that input and output shapes are consistent. |
| compute_spectral_norm | Compute spectral norm for analysis and regularization. |
Source code in spectrans/layers/mixing/base.py
Functions
get_spectral_properties (abstractmethod)
Get mathematical properties of the spectral operation.
Returns:
| Type | Description |
|---|---|
| dict[str, Any] | Dictionary of mathematical properties, such as 'unitary' (bool, whether the transform is unitary), 'real_output' (bool, whether the output is guaranteed real), 'frequency_domain' (bool, whether the operation occurs in the frequency domain), and 'energy_preserving' (bool, whether energy is preserved). |
Source code in spectrans/layers/mixing/base.py
verify_shape_consistency
Verify that input and output shapes are consistent.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| input_tensor | Tensor | Input tensor to the mixing layer. | required |
| output_tensor | Tensor | Output tensor from the mixing layer. | required |
Returns:
| Type | Description |
|---|---|
| bool | True if shapes are consistent, False otherwise. |
Source code in spectrans/layers/mixing/base.py
compute_spectral_norm
Compute spectral norm for analysis and regularization.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| tensor | Tensor | Input tensor to compute spectral norm for. | required |
Returns:
| Type | Description |
|---|---|
| Tensor | Spectral norm of the input tensor. |
Source code in spectrans/layers/mixing/base.py
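Assuming the usual definition of the spectral norm (the largest singular value of a matrix), it can be estimated by power iteration on \(\mathbf{A}^\top \mathbf{A}\). This NumPy sketch is illustrative only; how compute_spectral_norm treats batched or higher-rank tensors is not stated here, and spectral_norm is a hypothetical helper.

```python
import numpy as np

def spectral_norm(a, n_iter=100, seed=0):
    """Estimate the largest singular value of a 2D matrix by power iteration on A^T A."""
    v = np.random.default_rng(seed).standard_normal(a.shape[1])
    v /= np.linalg.norm(v)
    for _ in range(n_iter):
        v = a.T @ (a @ v)       # one power-iteration step on A^T A
        v /= np.linalg.norm(v)  # renormalize to avoid overflow
    return float(np.linalg.norm(a @ v))  # ||A v|| approaches sigma_max as v converges

# Build a 6 x 4 matrix with known singular values 5, 2, 1.
rng = np.random.default_rng(1)
u, _ = np.linalg.qr(rng.standard_normal((6, 3)))
w, _ = np.linalg.qr(rng.standard_normal((4, 3)))
a = u @ np.diag([5.0, 2.0, 1.0]) @ w.T
print(round(spectral_norm(a), 6))  # 5.0
```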
UnitaryMixingLayer
UnitaryMixingLayer(hidden_dim: int, dropout: float = 0.0, norm_eps: float = 1e-05, energy_tolerance: float = 0.0001)
Bases: MixingLayer
Base class for unitary mixing operations.
Unitary mixing layers preserve energy (vector norms) and inner products. These properties keep the scale of activations unchanged through the mixing step, which supports stable training and the theoretical guarantees of spectral transformers.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| hidden_dim | int | Hidden dimension of the model. | required |
| dropout | float | Dropout probability for regularization. | 0.0 |
| norm_eps | float | Epsilon for numerical stability. | 1e-5 |
| energy_tolerance | float | Tolerance for energy preservation verification. | 1e-4 |
Attributes:
| Name | Type | Description |
|---|---|---|
| energy_tolerance | float | Tolerance for energy preservation checks. |
Methods:
| Name | Description |
|---|---|
| get_spectral_properties | Get properties specific to unitary transforms. |
| verify_energy_preservation | Verify energy preservation (Parseval's theorem). |
| verify_orthogonality | Verify orthogonality of the transform matrix. |
Source code in spectrans/layers/mixing/base.py
Functions
get_spectral_properties
Get properties specific to unitary transforms.
Returns:
| Type | Description |
|---|---|
| dict[str, Any] | Dictionary containing unitary transform properties. |
Source code in spectrans/layers/mixing/base.py
verify_energy_preservation
Verify energy preservation (Parseval's theorem).
Checks that \(||\mathbf{output}||^2 \approx ||\mathbf{input}||^2\) within tolerance.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| input_tensor | Tensor | Input tensor before transformation. | required |
| output_tensor | Tensor | Output tensor after transformation. | required |
Returns:
| Type | Description |
|---|---|
| bool | True if energy is preserved within tolerance. |
Source code in spectrans/layers/mixing/base.py
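A NumPy sketch of the Parseval check: with orthonormal ("ortho") scaling the FFT preserves \(\lVert x \rVert^2\), while the unnormalized ("backward") forward transform scales it by the transform length \(n\). Using a relative tolerance is an assumption of this sketch; the library may compare energies differently, and energy_preserved is a hypothetical helper.

```python
import numpy as np

def energy_preserved(inp, out, tol=1e-4):
    """True if ||out||^2 matches ||inp||^2 within a relative tolerance tol."""
    e_in = np.sum(np.abs(inp) ** 2)
    e_out = np.sum(np.abs(out) ** 2)
    return bool(abs(e_out - e_in) <= tol * e_in)

x = np.random.default_rng(0).standard_normal((128, 32))
print(energy_preserved(x, np.fft.fft(x, axis=0, norm="ortho")))  # True
print(energy_preserved(x, np.fft.fft(x, axis=0)))                # False: energy scaled by 128
```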
verify_orthogonality
Verify orthogonality of the transform matrix.
Checks that \(\mathbf{T} \mathbf{T}^H \approx \mathbf{I}\) (identity matrix).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| transform_matrix | Tensor | Transform matrix to verify. | required |
Returns:
| Type | Description |
|---|---|
| bool | True if matrix is orthogonal within tolerance. |
Source code in spectrans/layers/mixing/base.py
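The orthogonality condition can be checked the same way in NumPy (is_unitary is a hypothetical helper, not the library method): the DFT matrix scaled by \(1/\sqrt{n}\) satisfies \(\mathbf{T}\mathbf{T}^H = \mathbf{I}\), whereas the unscaled matrix gives \(n\mathbf{I}\).

```python
import numpy as np

def is_unitary(t, tol=1e-4):
    """True if T @ T^H is the identity within tol (T^H is the conjugate transpose)."""
    return bool(np.allclose(t @ t.conj().T, np.eye(t.shape[0]), atol=tol))

n = 16
k = np.arange(n)
dft = np.exp(-2j * np.pi * np.outer(k, k) / n)
print(is_unitary(dft / np.sqrt(n)))  # True
print(is_unitary(dft))               # False: T T^H = n I
```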
FourierMixing
FourierMixing(hidden_dim: int, dropout: float = 0.0, norm_eps: float = 1e-05, energy_tolerance: float = 0.0001, fft_norm: FFTNorm = 'ortho', keep_complex: bool = False)
Bases: UnitaryMixingLayer
FNet-style Fourier mixing layer.
Implements the core FNet mixing operation using 2D Fourier transforms along both sequence and feature dimensions, providing an alternative to attention with \(O(n \log n)\) complexity.
The operation performs:
1. A 2D FFT across the sequence and feature dimensions.
2. Either extraction of the real part for the final output (the original FNet behavior, used by default) or, with keep_complex=True, retention of the complex values to preserve full information.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| hidden_dim | int | Hidden dimension of the input tensors. | required |
| dropout | float | Dropout probability applied after the mixing operation. | 0.0 |
| norm_eps | float | Epsilon for numerical stability. | 1e-5 |
| energy_tolerance | float | Tolerance for energy preservation verification. | 1e-4 |
| fft_norm | str | Normalization mode for FFT operations ("forward", "backward", "ortho"). | "ortho" |
| keep_complex | bool | If True, keeps complex values from FFT. If False (default), takes only the real part as in original FNet. | False |
Attributes:
| Name | Type | Description |
|---|---|---|
| fft2d | FFT2D | 2D Fourier transform module. |
| keep_complex | bool | Whether to keep complex values or extract real part. |
Methods:
| Name | Description |
|---|---|
| forward | Apply Fourier mixing to input tensor. |
| get_spectral_properties | Get spectral properties of Fourier mixing. |
| from_config | Create FourierMixing layer from configuration. |
Source code in spectrans/layers/mixing/fourier.py
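The FNet operation can be sketched in NumPy as follows. This mirrors the described math rather than FourierMixing's PyTorch code; dropout is omitted and fnet_mix is a hypothetical name. Only the complex output preserves energy; taking the real part discards the imaginary component.

```python
import numpy as np

def fnet_mix(x, keep_complex=False):
    """FNet-style mixing of x (batch, seq, hidden): 2D FFT over the last two axes."""
    y = np.fft.fft2(x, axes=(-2, -1), norm="ortho")
    return y if keep_complex else y.real  # real part by default, as in FNet

x = np.random.default_rng(0).standard_normal((2, 16, 8))
out = fnet_mix(x)
print(out.shape, out.dtype)  # (2, 16, 8) float64
# The complex output preserves energy under orthonormal scaling.
print(np.isclose(np.sum(np.abs(fnet_mix(x, keep_complex=True)) ** 2), np.sum(x ** 2)))  # True
```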
Functions
forward
Apply Fourier mixing to input tensor.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Input tensor of shape (batch_size, sequence_length, hidden_dim). | required |
Returns:
| Type | Description |
|---|---|
| Tensor | Mixed tensor of same shape. Complex if keep_complex=True, real values only if keep_complex=False (default). |
Source code in spectrans/layers/mixing/fourier.py
get_spectral_properties
Get spectral properties of Fourier mixing.
Returns:
| Type | Description |
|---|---|
| dict[str, str \| bool] | Properties including energy preservation and domain information. |
Source code in spectrans/layers/mixing/fourier.py
from_config (classmethod)
from_config(config: FourierMixingConfig) -> FourierMixing
Create FourierMixing layer from configuration.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | FourierMixingConfig | Configuration object with layer parameters. | required |
Returns:
| Type | Description |
|---|---|
| FourierMixing | Configured Fourier mixing layer. |
Source code in spectrans/layers/mixing/fourier.py
FourierMixing1D
FourierMixing1D(hidden_dim: int, dropout: float = 0.0, norm_eps: float = 1e-05, energy_tolerance: float = 0.0001, fft_norm: FFTNorm = 'ortho', keep_complex: bool = False)
Bases: UnitaryMixingLayer
1D Fourier mixing along sequence dimension only.
Applies Fourier transform only along the sequence dimension, preserving feature dimension locality while mixing tokens.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| hidden_dim | int | Hidden dimension of the input tensors. | required |
| dropout | float | Dropout probability applied after the mixing operation. | 0.0 |
| norm_eps | float | Epsilon for numerical stability. | 1e-5 |
| energy_tolerance | float | Tolerance for energy preservation verification. | 1e-4 |
| fft_norm | str | Normalization mode for FFT operations. | "ortho" |
| keep_complex | bool | If True, keeps complex values from FFT. If False (default), takes only the real part. | False |
Attributes:
| Name | Type | Description |
|---|---|---|
| fft1d | FFT1D | 1D Fourier transform module. |
| keep_complex | bool | Whether to keep complex values or extract real part. |
Methods:
| Name | Description |
|---|---|
| forward | Apply 1D Fourier mixing to input tensor. |
| get_spectral_properties | Get spectral properties of 1D Fourier mixing. |
Source code in spectrans/layers/mixing/fourier.py
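A NumPy sketch of sequence-only mixing (fourier_mix_1d is a hypothetical helper; dropout is omitted). Because the FFT runs along the sequence axis only, each hidden channel is transformed independently, which is what preserves feature-dimension locality.

```python
import numpy as np

def fourier_mix_1d(x, keep_complex=False):
    """Mix tokens of x (batch, seq, hidden) with an FFT along the sequence axis only."""
    y = np.fft.fft(x, axis=1, norm="ortho")
    return y if keep_complex else y.real

x = np.random.default_rng(0).standard_normal((2, 16, 8))
out = fourier_mix_1d(x)
# Channel 3 of the output depends only on channel 3 of the input.
print(np.allclose(out[..., 3], np.fft.fft(x[..., 3], axis=1, norm="ortho").real))  # True
```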
Functions
forward
Apply 1D Fourier mixing to input tensor.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Input tensor of shape (batch_size, sequence_length, hidden_dim). | required |
Returns:
| Type | Description |
|---|---|
| Tensor | Mixed tensor with Fourier transform applied along sequence dimension. Complex if keep_complex=True, real values only if keep_complex=False. |
Source code in spectrans/layers/mixing/fourier.py
get_spectral_properties
Get spectral properties of 1D Fourier mixing.
Returns:
| Type | Description |
|---|---|
| dict[str, str \| bool] | Properties specific to 1D sequence mixing. |
Source code in spectrans/layers/mixing/fourier.py
RealFourierMixing
RealFourierMixing(hidden_dim: int, use_real_fft: bool = True, dropout: float = 0.0, norm_eps: float = 1e-05, energy_tolerance: float = 0.0001, fft_norm: FFTNorm = 'ortho')
Bases: UnitaryMixingLayer
Memory-efficient real Fourier mixing.
Uses real-input FFT operations, which exploit the Hermitian symmetry of a real signal's spectrum to store only about half of the frequency bins, providing roughly 2x memory and computational savings for real inputs.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| hidden_dim | int | Hidden dimension of the input tensors. | required |
| use_real_fft | bool | Whether to use real FFT for efficiency. | True |
| dropout | float | Dropout probability applied after mixing. | 0.0 |
| norm_eps | float | Epsilon for numerical stability. | 1e-5 |
| energy_tolerance | float | Tolerance for energy preservation verification. | 1e-4 |
| fft_norm | str | Normalization mode for FFT operations. | "ortho" |
Attributes:
| Name | Type | Description |
|---|---|---|
| use_real_fft | bool | Whether real FFT is enabled. |
| rfft | RFFT | Real FFT transform for sequence dimension. |
| rfft2d | RFFT2D | Real 2D FFT transform for both dimensions. |
Methods:
| Name | Description |
|---|---|
| forward | Apply real Fourier mixing to input tensor. |
| get_spectral_properties | Get spectral properties of real Fourier mixing. |
Source code in spectrans/layers/mixing/fourier.py
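The Hermitian-symmetry saving can be illustrated with NumPy's real FFT. This shows the property that RealFourierMixing relies on, not the class's own code path; in particular, how the layer maps the half spectrum back to a real output is not specified here.

```python
import numpy as np

x = np.random.default_rng(0).standard_normal((16, 8))  # real input (seq, hidden)
full = np.fft.fft(x, axis=0, norm="ortho")   # 16 complex bins per channel
half = np.fft.rfft(x, axis=0, norm="ortho")  # 16 // 2 + 1 = 9 bins per channel

print(full.shape, half.shape)       # (16, 8) (9, 8)
print(np.allclose(half, full[:9]))  # True: rfft keeps the non-redundant bins
# Hermitian symmetry X[n - k] = conj(X[k]) makes the remaining 7 bins redundant.
print(np.allclose(full[1:][::-1], full[1:].conj()))  # True
# The real signal is recovered exactly from the half spectrum.
print(np.allclose(np.fft.irfft(half, n=16, axis=0, norm="ortho"), x))  # True
```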
Functions
forward
Apply real Fourier mixing to input tensor.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Input tensor of shape (batch_size, sequence_length, hidden_dim). Should be real-valued for optimal efficiency. | required |
Returns:
| Type | Description |
|---|---|
| Tensor | Mixed tensor, guaranteed to be real-valued. |
Source code in spectrans/layers/mixing/fourier.py
get_spectral_properties
Get spectral properties of real Fourier mixing.
Returns:
| Type | Description |
|---|---|
| dict[str, str \| bool] | Properties including efficiency characteristics. |
Source code in spectrans/layers/mixing/fourier.py
SeparableFourierMixing
SeparableFourierMixing(hidden_dim: int, mix_features: bool = True, mix_sequence: bool = True, dropout: float = 0.0, norm_eps: float = 1e-05, energy_tolerance: float = 0.0001, fft_norm: FFTNorm = 'ortho')
Bases: UnitaryMixingLayer
Separable Fourier mixing with sequence and feature transforms.
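Applies separate 1D Fourier transforms along the sequence and/or feature dimensions. Because each axis can be enabled independently (mix_sequence, mix_features), this gives finer control than a fixed 2D FFT, and the separate transforms can be more efficient than a 2D FFT for certain tensor shapes.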
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| hidden_dim | int | Hidden dimension of the input tensors. | required |
| mix_features | bool | Whether to apply FFT along feature dimension. | True |
| mix_sequence | bool | Whether to apply FFT along sequence dimension. | True |
| dropout | float | Dropout probability. | 0.0 |
| norm_eps | float | Epsilon for numerical stability. | 1e-5 |
| energy_tolerance | float | Tolerance for energy preservation verification. | 1e-4 |
| fft_norm | str | FFT normalization mode. | "ortho" |
Attributes:
| Name | Type | Description |
|---|---|---|
| mix_features | bool | Whether feature mixing is enabled. |
| mix_sequence | bool | Whether sequence mixing is enabled. |
| fft1d | FFT1D | 1D FFT transform module. |
Methods:
| Name | Description |
|---|---|
| forward | Apply separable Fourier mixing. |
| get_spectral_properties | Get properties of separable mixing. |
Source code in spectrans/layers/mixing/fourier.py
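A NumPy sketch of the separable scheme (separable_fourier_mix is a hypothetical helper that keeps complex output). Since the 2D DFT is separable, enabling both axes reproduces a 2D FFT; switching one flag off leaves that axis untouched. Whether SeparableFourierMixing extracts a real part afterwards is not stated here.

```python
import numpy as np

def separable_fourier_mix(x, mix_sequence=True, mix_features=True):
    """Apply 1D FFTs along the enabled axes of x with shape (batch, seq, hidden)."""
    y = x.astype(complex)
    if mix_sequence:
        y = np.fft.fft(y, axis=1, norm="ortho")  # token mixing
    if mix_features:
        y = np.fft.fft(y, axis=2, norm="ortho")  # feature mixing
    return y

x = np.random.default_rng(0).standard_normal((2, 16, 8))
both = separable_fourier_mix(x)
print(np.allclose(both, np.fft.fft2(x, axes=(1, 2), norm="ortho")))  # True
seq_only = separable_fourier_mix(x, mix_features=False)
print(np.allclose(seq_only, np.fft.fft(x, axis=1, norm="ortho")))    # True
```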
Functions
forward
Apply separable Fourier mixing.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Input tensor of shape (batch_size, sequence_length, hidden_dim). | required |
Returns:
| Type | Description |
|---|---|
| Tensor | Mixed tensor after applying selected transforms. |
Source code in spectrans/layers/mixing/fourier.py
get_spectral_properties
Get properties of separable mixing.
Returns:
| Type | Description |
|---|---|
| dict[str, str \| bool] | Properties reflecting the separable nature. |
Source code in spectrans/layers/mixing/fourier.py
AdaptiveGlobalFilter
AdaptiveGlobalFilter(hidden_dim: int, sequence_length: int, activation: ActivationType = 'sigmoid', dropout: float = 0.0, norm_eps: float = 1e-05, learnable_filters: bool = True, fft_norm: FFTNorm = 'ortho', filter_init_std: float = 0.02, filter_regularization: float = 0.0, adaptive_initialization: bool = True, spectral_dropout_p: float = 0.0)
Bases: FilterMixingLayer
Adaptive Global Filter with regularization and frequency-aware initialization.
Enhanced version of global filtering with adaptive initialization strategies, regularization options, and improved training stability.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| hidden_dim | int | Hidden dimension of input tensors. | required |
| sequence_length | int | Expected sequence length. | required |
| activation | ActivationType | Filter activation function. | "sigmoid" |
| dropout | float | Dropout probability. | 0.0 |
| norm_eps | float | Numerical stability epsilon. | 1e-5 |
| learnable_filters | bool | Whether filters are learnable. | True |
| fft_norm | str | FFT normalization. | "ortho" |
| filter_init_std | float | Filter initialization standard deviation. | 0.02 |
| filter_regularization | float | L2 regularization strength for filter parameters. | 0.0 |
| adaptive_initialization | bool | Whether to use frequency-aware initialization. | True |
| spectral_dropout_p | float | Spectral dropout probability in frequency domain. | 0.0 |
Attributes:
| Name | Type | Description |
|---|---|---|
| filter_regularization | float | Regularization strength. |
| adaptive_initialization | bool | Whether adaptive initialization is used. |
| spectral_dropout_p | float | Spectral dropout probability. |
| spectral_dropout | Module | Spectral dropout layer. |
Methods:
| Name | Description |
|---|---|
| forward | Apply adaptive global filtering. |
| get_filter_response | Get adaptive frequency response. |
| get_regularization_loss | Compute L2 regularization loss for filter parameters. |
| get_spectral_properties | Get adaptive filter properties. |
Source code in spectrans/layers/mixing/global_filter.py
Functions
forward
Apply adaptive global filtering.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Input tensor of shape (batch_size, sequence_length, hidden_dim). | required |
Returns:
| Type | Description |
|---|---|
| Tensor | Adaptively filtered tensor. |
Source code in spectrans/layers/mixing/global_filter.py
get_filter_response
Get adaptive frequency response.
Returns:
| Type | Description |
|---|---|
| Tensor | Complex frequency response with current parameters. |
Source code in spectrans/layers/mixing/global_filter.py
get_regularization_loss
Compute L2 regularization loss for filter parameters.
Returns:
| Type | Description |
|---|---|
| Tensor | Scalar regularization loss. |
Source code in spectrans/layers/mixing/global_filter.py
get_spectral_properties
Get adaptive filter properties.
Returns:
| Type | Description |
|---|---|
| dict[str, str \| bool \| int] | Comprehensive properties including adaptive features. |
Source code in spectrans/layers/mixing/global_filter.py
GlobalFilterMixing
GlobalFilterMixing(hidden_dim: int, sequence_length: int, activation: ActivationType = 'sigmoid', dropout: float = 0.0, norm_eps: float = 1e-05, learnable_filters: bool = True, fft_norm: FFTNorm = 'ortho', filter_init_std: float = 0.02)
Bases: FilterMixingLayer
Global Filter Network mixing layer.
Implements the core GFNet mixing operation with learnable complex filters applied in the frequency domain along the sequence dimension.
When the input sequence length differs from sequence_length, the learned filters are interpolated to the required number of frequency bins, so variable-length inputs can be processed while the learned frequency patterns are preserved. This gives a degree of resolution independence that fixed-size filtering lacks.
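The interpolation scheme is not documented here. The NumPy sketch below assumes linear interpolation of the real and imaginary parts over a normalized frequency axis; interpolate_filter is a hypothetical helper, not part of the library.

```python
import numpy as np

def interpolate_filter(filt, target_len):
    """Linearly resample a complex filter of shape (freq_bins, hidden) to target_len bins.

    Bins sit on a shared normalized axis [0, 1], so low and high frequencies stay
    aligned when the number of bins changes.
    """
    src = np.linspace(0.0, 1.0, filt.shape[0])
    dst = np.linspace(0.0, 1.0, target_len)
    cols = [
        np.interp(dst, src, filt[:, c].real) + 1j * np.interp(dst, src, filt[:, c].imag)
        for c in range(filt.shape[1])
    ]
    return np.stack(cols, axis=1)

base = np.random.default_rng(0).standard_normal((9, 4)) + 0j  # filter learned at 9 bins
resized = interpolate_filter(base, 17)
print(resized.shape)                    # (17, 4)
print(np.allclose(resized[::2], base))  # True: the original bins are kept
```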
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| hidden_dim | int | Hidden dimension of input tensors. | required |
| sequence_length | int | Base sequence length for filter parameter initialization. The filters will be interpolated to match actual input sequence lengths. | required |
| activation | ActivationType | Activation function applied to filter parameters ("sigmoid", "tanh", "identity"). | "sigmoid" |
| dropout | float | Dropout probability applied after filtering. | 0.0 |
| norm_eps | float | Epsilon for numerical stability. | 1e-5 |
| learnable_filters | bool | Whether filter parameters are learnable (always True for this class). | True |
| fft_norm | str | FFT normalization mode. | "ortho" |
| filter_init_std | float | Standard deviation for filter parameter initialization. | 0.02 |
Attributes:
| Name | Type | Description |
|---|---|---|
| activation | str | Activation function name. |
| filter_real | Parameter | Real part of complex filter parameters. |
| filter_imag | Parameter | Imaginary part of complex filter parameters. |
| fft1d | FFT1D | 1D FFT transform for sequence dimension. |
| activation_fn | Module | Activation function module (Sigmoid, Tanh, or Identity). |
Methods:
| Name | Description |
|---|---|
| forward | Apply global filtering to input tensor. |
| get_filter_response | Get the current frequency response of the filters. |
| get_spectral_properties | Get spectral properties of global filtering. |
| from_config | Create GlobalFilterMixing layer from configuration. |
Source code in spectrans/layers/mixing/global_filter.py
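A NumPy sketch of the GFNet-style operation along the sequence axis: FFT, element-wise multiplication by a complex filter, inverse FFT. Forming the filter as sigmoid(filter_real) + 1j * sigmoid(filter_imag) is an assumption based on the activation parameter, and taking the real part of the inverse transform is one simple way to return a real tensor; make_filter and apply_global_filter are hypothetical names.

```python
import numpy as np

def make_filter(w_real, w_imag):
    """Complex filter from raw parameters; applying a sigmoid is an assumption here."""
    def sigmoid(z):
        return 1.0 / (1.0 + np.exp(-z))
    return sigmoid(w_real) + 1j * sigmoid(w_imag)

def apply_global_filter(x, h):
    """Filter x (batch, seq, hidden) along the sequence axis with a complex h (seq, hidden)."""
    x_hat = np.fft.fft(x, axis=1, norm="ortho")       # tokens -> frequencies
    y = np.fft.ifft(x_hat * h, axis=1, norm="ortho")  # element-wise filter, then invert
    return y.real                                     # keep a real-valued output

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 16, 8))
h = make_filter(0.02 * rng.standard_normal((16, 8)), 0.02 * rng.standard_normal((16, 8)))
out = apply_global_filter(x, h)
print(out.shape)  # (2, 16, 8)
# An all-pass filter (H = 1 everywhere) returns the input unchanged.
print(np.allclose(apply_global_filter(x, np.ones((16, 8))), x))  # True
```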
Functions
forward
Apply global filtering to input tensor.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Input tensor of shape (batch_size, sequence_length, hidden_dim). | required |
Returns:
| Type | Description |
|---|---|
| Tensor | Filtered tensor of same shape as input. |
Source code in spectrans/layers/mixing/global_filter.py
get_filter_response
Get the current frequency response of the filters.
Returns:
| Type | Description |
|---|---|
| Tensor | Complex-valued frequency response of shape (sequence_length, hidden_dim). |
Source code in spectrans/layers/mixing/global_filter.py
get_spectral_properties
Get spectral properties of global filtering.
Returns:
| Type | Description |
|---|---|
| dict[str, str \| bool \| int] | Properties including filter characteristics. |
Source code in spectrans/layers/mixing/global_filter.py
from_config (classmethod)
from_config(config: GlobalFilterMixingConfig) -> GlobalFilterMixing
Create GlobalFilterMixing layer from configuration.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | GlobalFilterMixingConfig | Configuration object with layer parameters. | required |
Returns:
| Type | Description |
|---|---|
| GlobalFilterMixing | Configured global filter mixing layer. |
Source code in spectrans/layers/mixing/global_filter.py
GlobalFilterMixing2D
GlobalFilterMixing2D(hidden_dim: int, sequence_length: int, activation: ActivationType = 'sigmoid', dropout: float = 0.0, norm_eps: float = 1e-05, learnable_filters: bool = True, fft_norm: FFTNorm = 'ortho', filter_init_std: float = 0.02)
Bases: FilterMixingLayer
2D Global Filter mixing with filtering along both dimensions.
Extends global filtering to both sequence and feature dimensions, similar to FNet's 2D FFT but with learnable filters.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| hidden_dim | int | Hidden dimension of input tensors. | required |
| sequence_length | int | Expected sequence length. | required |
| activation | ActivationType | Activation function for filter parameters. | "sigmoid" |
| dropout | float | Dropout probability. | 0.0 |
| norm_eps | float | Epsilon for numerical stability. | 1e-5 |
| learnable_filters | bool | Whether filters are learnable. | True |
| fft_norm | str | FFT normalization mode. | "ortho" |
| filter_init_std | float | Filter parameter initialization standard deviation. | 0.02 |
Attributes:
| Name | Type | Description |
|---|---|---|
| filter_real | Parameter | Real part of 2D complex filters. |
| filter_imag | Parameter | Imaginary part of 2D complex filters. |
| fft2d | FFT2D | 2D FFT transform module. |
| activation_fn | Module | Activation function. |
Methods:
| Name | Description |
|---|---|
| forward | Apply 2D global filtering. |
| get_filter_response | Get 2D frequency response. |
| get_spectral_properties | Get 2D filter properties. |
Source code in spectrans/layers/mixing/global_filter.py
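The 2D variant replaces the 1D FFT with a 2D FFT over the sequence and hidden axes. A minimal NumPy sketch follows; global_filter_mix_2d is a hypothetical helper, and the filter h is passed in directly rather than built from learnable parameters.

```python
import numpy as np

def global_filter_mix_2d(x, h):
    """Filter x (batch, seq, hidden) with a complex 2D filter h (seq, hidden)."""
    x_hat = np.fft.fft2(x, axes=(1, 2), norm="ortho")      # 2D spectrum per example
    y = np.fft.ifft2(x_hat * h, axes=(1, 2), norm="ortho")  # filter, then invert
    return y.real

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 16, 8))
print(np.allclose(global_filter_mix_2d(x, np.ones((16, 8))), x))   # True: all-pass
print(np.allclose(global_filter_mix_2d(x, np.zeros((16, 8))), 0))  # True: blocks everything
```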
Functions
forward
Apply 2D global filtering.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Input tensor of shape (batch_size, sequence_length, hidden_dim). | required |
Returns:
| Type | Description |
|---|---|
| Tensor | Filtered tensor of same shape. |
Source code in spectrans/layers/mixing/global_filter.py
get_filter_response
Get 2D frequency response.
Returns:
| Type | Description |
|---|---|
| Tensor | Complex 2D frequency response. |
Source code in spectrans/layers/mixing/global_filter.py
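The complex frequency response can be read through the convolution theorem: multiplying the 2D spectrum by a response H is equivalent to a circular 2D convolution over the full (sequence, feature) grid with the kernel ifft2(H), which is what makes the filter global. The NumPy check below uses an arbitrary random H rather than the layer's own response.

```python
import numpy as np


def circular_conv2d(x, k):
    """Direct 2D circular convolution: y[m, n] = sum_{p, q} x[p, q] * k[(m - p) % M, (n - q) % N]."""
    M, N = x.shape
    y = np.zeros((M, N), dtype=complex)
    for m in range(M):
        for n in range(N):
            for p in range(M):
                for q in range(N):
                    y[m, n] += x[p, q] * k[(m - p) % M, (n - q) % N]
    return y


rng = np.random.default_rng(1)
x = rng.standard_normal((6, 4))                                     # one (seq_len, hidden_dim) slice
H = rng.standard_normal((6, 4)) + 1j * rng.standard_normal((6, 4))  # arbitrary complex response

via_fft = np.fft.ifft2(np.fft.fft2(x) * H)  # frequency-domain filtering
kernel = np.fft.ifft2(H)                    # equivalent kernel over the whole grid
direct = circular_conv2d(x, kernel)
print(np.allclose(via_fft, direct))         # True
```

The same identity holds with norm="ortho" on both transforms, because the two scale factors cancel.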
get_spectral_properties
Get 2D filter properties.
Returns:
| Type | Description |
|---|---|
| dict[str, str \| bool \| int] | 2D filtering characteristics. |
Source code in spectrans/layers/mixing/global_filter.py
WaveletMixing
WaveletMixing(hidden_dim: int, wavelet: WaveletType = 'db4', levels: int = 3, mixing_mode: str = 'pointwise', dropout: float = 0.0)
Bases: Module
Token mixing layer using discrete wavelet transform.
Performs mixing in wavelet domain for multi-resolution processing. Decomposes input using DWT, applies learnable mixing to coefficients, and reconstructs the output with residual connections.
Mathematical Formulation
Given input tensor \(\mathbf{X} \in \mathbb{R}^{B \times N \times D}\) where \(B\) is batch size, \(N\) is sequence length, and \(D\) is hidden dimension:
Step 1: Channel-wise Decomposition
For each channel \(d \in \{0, 1, \ldots, D-1\}\), extract the channel signal:
\[
\mathbf{x}^{(d)} = \mathbf{X}[:, :, d] \in \mathbb{R}^{B \times N}
\]
Apply \(J\)-level DWT decomposition:
\[
\text{DWT}_J(\mathbf{x}^{(d)}) = \left\{\mathbf{c}_{A_J}^{(d)}, \{\mathbf{c}_{D_j}^{(d)}\}_{j=1}^{J}\right\}
\]
Where:
- \(\mathbf{c}_{A_J}^{(d)} \in \mathbb{R}^{B \times L_{A_J}}\) are approximation coefficients at level \(J\)
- \(\mathbf{c}_{D_j}^{(d)} \in \mathbb{R}^{B \times L_{D_j}}\) are detail coefficients at level \(j\)
- \(L_{A_J}\) and \(L_{D_j}\) are coefficient lengths after subsampling
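To make the coefficient lengths concrete, here is a NumPy sketch of a \(J\)-level DWT for the simplest wavelet, Haar ('db1'), with no boundary padding. It is an illustration, not the library's DWT1D: the default 'db4' has longer filters, and depending on boundary handling its lengths \(L_{A_J}\), \(L_{D_j}\) can differ slightly from the exact halving shown here.

```python
import numpy as np


def haar_dwt(x, levels):
    """J-level Haar ('db1') DWT along the last axis.

    Returns (approx_J, [detail_1, ..., detail_J]); assumes the length is divisible by 2**levels.
    """
    approx, details = x, []
    for _ in range(levels):
        even, odd = approx[..., 0::2], approx[..., 1::2]
        details.append((even - odd) / np.sqrt(2))  # high-pass, subsampled by 2
        approx = (even + odd) / np.sqrt(2)         # low-pass, subsampled by 2
    return approx, details


def haar_idwt(approx, details):
    """Invert haar_dwt, undoing the levels from coarsest to finest."""
    for detail in reversed(details):
        even = (approx + detail) / np.sqrt(2)
        odd = (approx - detail) / np.sqrt(2)
        # Interleave even and odd samples back into one signal
        approx = np.stack([even, odd], axis=-1).reshape(*even.shape[:-1], -1)
    return approx


x_d = np.random.default_rng(0).standard_normal((2, 16))  # channel signal x^(d), shape (B, N)
c_A, c_D = haar_dwt(x_d, levels=3)
print(c_A.shape, [c.shape for c in c_D])      # (2, 2) [(2, 8), (2, 4), (2, 2)]
print(np.allclose(haar_idwt(c_A, c_D), x_d))  # True
```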
Step 2: Learnable Mixing
Apply mixing transformations based on mode:
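Pointwise Mixing (mixing_mode='pointwise'):
\[
\tilde{\mathbf{c}}_{A_J}^{(d)} = \mathbf{c}_{A_J}^{(d)} \odot \mathbf{W}_{A}[0, :L_{A_J}, d], \qquad \tilde{\mathbf{c}}_{D_j}^{(d)} = \mathbf{c}_{D_j}^{(d)} \odot \mathbf{W}_{D_j}[0, :L_{D_j}, d]
\]
Where \(\mathbf{W}_{A}, \mathbf{W}_{D_j} \in \mathbb{R}^{1 \times \max(L) \times D}\) are learnable parameters, sliced to the length of each coefficient vector, and \(\odot\) denotes element-wise multiplication with broadcasting.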
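A NumPy sketch of pointwise mixing with dynamic weight slicing. It assumes the coefficients of all \(D\) channels are stacked on the last axis, so slice [:, :, d] of the result holds the mixed coefficients of channel \(d\); the array names are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
B, D, max_len = 2, 4, 8

# Hypothetical weight for one coefficient group, with the documented shape (1, max(L), D)
W_A = rng.standard_normal((1, max_len, D))

# Coefficients of all D channels stacked on the last axis, shape (B, L_A, D) with L_A = 3
c_A = rng.standard_normal((B, 3, D))

# Dynamic slicing keeps the first L_A weight rows; the leading 1 broadcasts over the batch
L_A = c_A.shape[1]
mixed_A = c_A * W_A[:, :L_A, :]
print(mixed_A.shape)  # (2, 3, 4)
```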
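Channel Mixing (mixing_mode='channel'):
\[
\tilde{\mathbf{C}}_{A_J} = \mathbf{C}_{A_J} \mathbf{W}_{A}, \qquad \tilde{\mathbf{C}}_{D_j} = \mathbf{C}_{D_j} \mathbf{W}_{D_j}
\]
Where \(\mathbf{C}_{A_J} \in \mathbb{R}^{B \times L_{A_J} \times D}\) and \(\mathbf{C}_{D_j} \in \mathbb{R}^{B \times L_{D_j} \times D}\) stack the coefficients of all \(D\) channels, and \(\mathbf{W}_{A}, \mathbf{W}_{D_j} \in \mathbb{R}^{1 \times D \times D}\) are initialized as identity matrices.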
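A NumPy sketch of channel mixing for one coefficient group, assuming the coefficients of all channels are stacked as (B, L, D) and multiplied by a (1, D, D) matrix. It shows why identity initialization leaves the coefficients unchanged and how a trained matrix couples channels.

```python
import numpy as np

rng = np.random.default_rng(0)
B, L, D = 2, 8, 4
c_D1 = rng.standard_normal((B, L, D))  # level-1 detail coefficients, channels on the last axis

# Channel-mixing weight of shape (1, D, D), initialized to the identity
W_D1 = np.eye(D)[None, :, :]
mixed = c_D1 @ W_D1                    # matmul over the channel axis, broadcast over the batch
print(np.allclose(mixed, c_D1))        # True

# Once the weight moves away from the identity, every output channel mixes all input channels
W_trained = W_D1 + 0.1 * rng.standard_normal((1, D, D))
mixed_trained = c_D1 @ W_trained
print(mixed_trained.shape)             # (2, 8, 4)
```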
Level Mixing (mixing_mode='level'):
Cross-level attention is applied to the coefficients of all levels simultaneously.
Step 3: Reconstruction
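Reconstruct the signal using inverse DWT:
\[
\hat{\mathbf{x}}^{(d)} = \text{IDWT}_J\left(\tilde{\mathbf{c}}_{A_J}^{(d)}, \{\tilde{\mathbf{c}}_{D_j}^{(d)}\}_{j=1}^{J}\right)
\]
If the reconstructed length differs from \(N\), \(\hat{\mathbf{x}}^{(d)}\) is padded or truncated to length \(N\).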
Step 4: Residual Connection and Dropout
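Combine all channels and apply the residual connection:
\[
\hat{\mathbf{X}} = [\hat{\mathbf{x}}^{(0)}, \ldots, \hat{\mathbf{x}}^{(D-1)}] \in \mathbb{R}^{B \times N \times D}, \qquad \mathbf{Y} = \mathbf{X} + \text{Dropout}(\hat{\mathbf{X}})
\]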
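Steps 3 and 4 in NumPy, assuming zero padding for the length adjustment and PyTorch-style inverted dropout (entries dropped with probability p, survivors scaled by 1/(1-p)). The helper names are hypothetical.

```python
import numpy as np


def match_length(x_hat, n):
    """Truncate, or zero-pad, the last axis of x_hat to length n (zero padding is an assumption)."""
    length = x_hat.shape[-1]
    if length >= n:
        return x_hat[..., :n]
    pad_width = [(0, 0)] * (x_hat.ndim - 1) + [(0, n - length)]
    return np.pad(x_hat, pad_width)


def dropout(t, p, rng, training=True):
    """Inverted dropout: zero each entry with probability p and scale survivors by 1 / (1 - p)."""
    if not training or p == 0.0:
        return t
    keep = rng.random(t.shape) >= p
    return t * keep / (1.0 - p)


rng = np.random.default_rng(0)
B, N, D = 2, 10, 3
X = rng.standard_normal((B, N, D))
# Hypothetical per-channel reconstructions that came back one sample too long
recon = [rng.standard_normal((B, N + 1)) for _ in range(D)]
X_hat = np.stack([match_length(r, N) for r in recon], axis=-1)  # (B, N, D)
Y = X + dropout(X_hat, p=0.1, rng=rng)  # training mode
print(Y.shape)  # (2, 10, 3)
```

In evaluation mode the dropout is the identity, so \(\mathbf{Y} = \mathbf{X} + \hat{\mathbf{X}}\). One consequence of this formula: if the mixing weights act as the identity and reconstruction is perfect, then \(\hat{\mathbf{X}} = \mathbf{X}\) and the output starts at \(2\mathbf{X}\) rather than \(\mathbf{X}\).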
Complexity Analysis
- Time Complexity: \(O(DNJ)\) per forward pass
  - \(O(N)\) for DWT/IDWT per level and channel (linear in signal length)
  - \(O(DN)\) for pointwise mixing, since each coefficient is multiplied once; channel and level modes add the cost of their matrix products or attention
- Space Complexity: \(O(DN + P)\) where \(P\) is parameter count
  - \(O(DN)\) for storing coefficient tensors
  - Parameter count depends on mixing mode:
    - Pointwise: \(P = O(LD)\) where \(L\) is max coefficient length
    - Channel: \(P = O(JD^2)\)
    - Level: \(P = O(D^2)\) for attention parameters
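The per-level \(O(N)\) bound is conservative: each level only processes the approximation left by the previous one, so the work forms a geometric series. Below is a worked count assuming periodization (every level exactly halves the length) and K-tap analysis filters (K = 8 for 'db4').

```python
def dwt_lengths(n, levels):
    """Coefficient lengths of a J-level DWT with periodization (n divisible by 2**levels)."""
    detail_lengths = [n // 2 ** j for j in range(1, levels + 1)]
    return n // 2 ** levels, detail_lengths


def dwt_multiply_adds(n, levels, taps):
    """Multiply-adds of a J-level DWT: per level, a low-pass and a high-pass filter
    with `taps` coefficients each produce half as many outputs as the level's input."""
    total, length = 0, n
    for _ in range(levels):
        total += 2 * taps * (length // 2)
        length //= 2
    return total


approx_len, detail_lens = dwt_lengths(1024, 3)
print(approx_len, detail_lens, approx_len + sum(detail_lens))  # 128 [512, 256, 128] 1024
print(dwt_multiply_adds(1024, 3, taps=8))                      # 14336
```

The total coefficient count equals \(N\), and the multiply-add count stays below \(2KN = 16384\) for any \(J\), so summing \(O(N)\) per level over \(J\) levels overestimates the per-channel transform cost.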
Implementation Notes
- Uses PyTorch-native DWT implementation for gradient compatibility
- Dynamic weight slicing ensures proper alignment with variable-length coefficients
- Perfect reconstruction property maintained through careful length handling
- Each channel processed independently for computational efficiency
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| hidden_dim | int | Hidden dimension size \(D\). | required |
| wavelet | str | Wavelet type (e.g., 'db1', 'db4', 'sym2'). Determines filter bank characteristics. | 'db4' |
| levels | int | Number of decomposition levels \(J\). Controls resolution hierarchy. | 3 |
| mixing_mode | str | Mixing strategy: 'pointwise' (element-wise), 'channel' (diagonal), 'level' (attention). | 'pointwise' |
| dropout | float | Dropout probability applied to the reconstructed signal before the residual connection. | 0.0 |
Attributes:
| Name | Type | Description |
|---|---|---|
| dwt | DWT1D | Wavelet transform module implementing PyTorch-native DWT/IDWT. |
| mixing_weights | ParameterDict | Learnable parameters for coefficient mixing; structure depends on mixing_mode. |
| dropout | Dropout | Dropout layer for regularization. |
Raises:
| Type | Description |
|---|---|
| ValueError | If :attr: |
Examples:
Basic usage with pointwise mixing:
>>> import torch
>>> from spectrans.layers.mixing import WaveletMixing
>>> mixer = WaveletMixing(hidden_dim=256, wavelet='db4', levels=3)
>>> x = torch.randn(32, 128, 256) # (batch, seq_len, hidden)
>>> output = mixer(x)
>>> assert output.shape == x.shape
Channel mixing with identity initialization:
>>> mixer = WaveletMixing(hidden_dim=64, mixing_mode='channel', levels=2)
>>> x = torch.randn(16, 64, 64)
>>> output = mixer(x)
>>> # Identity-initialized channel weights leave the wavelet coefficients unchanged at first
Cross-level mixing with attention:
>>> mixer = WaveletMixing(hidden_dim=128, mixing_mode='level', levels=4)
>>> x = torch.randn(8, 256, 128)
>>> output = mixer(x) # Attention applied across wavelet levels
Methods:
| Name | Description |
|---|---|
| forward | Apply wavelet-based mixing following the mathematical formulation. |
| from_config | Create WaveletMixing from configuration. |
Source code in spectrans/layers/mixing/wavelet.py
Functions
forward
Apply wavelet-based mixing following the mathematical formulation.
Implements the complete wavelet mixing pipeline: decomposition → mixing → reconstruction → residual. Each hidden dimension is processed independently to maintain channel separability.
Mathematical Implementation
The forward pass implements the mathematical formulation exactly:
- Channel Extraction: \(\mathbf{x}^{(d)} = \mathbf{X}[:, :, d]\) for \(d = 0, \ldots, D-1\)
- Wavelet Decomposition: \(\text{DWT}_J(\mathbf{x}^{(d)}) \rightarrow \{\mathbf{c}_{A_J}^{(d)}, \{\mathbf{c}_{D_j}^{(d)}\}\}\)
- Learnable Mixing: Apply mode-specific transformations to coefficients
- Signal Reconstruction: \(\text{IDWT}_J(\text{mixed coefficients}) \rightarrow \hat{\mathbf{x}}^{(d)}\)
- Channel Concatenation: \(\hat{\mathbf{X}} = [\hat{\mathbf{x}}^{(0)}, \ldots, \hat{\mathbf{x}}^{(D-1)}]\)
- Residual Connection: \(\mathbf{Y} = \mathbf{X} + \text{Dropout}(\hat{\mathbf{X}})\)
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Input tensor of shape \((B, N, D)\) where \(B\) is batch size, \(N\) is sequence length, and \(D\) is hidden dimension. | required |
Returns:
| Type | Description |
|---|---|
| Tensor | Mixed output tensor of identical shape \((B, N, D)\) with wavelet-domain mixing applied and residual connection. |
Notes
- Dynamic coefficient length handling ensures robustness to varying sequence lengths
- Perfect reconstruction property maintained through careful padding/truncation
- Gradient flow preserved through PyTorch-native operations
Source code in spectrans/layers/mixing/wavelet.py
from_config classmethod
from_config(config: WaveletMixingConfig) -> WaveletMixing
Create WaveletMixing from configuration.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | WaveletMixingConfig | Typed and validated configuration. | required |
Returns:
| Type | Description |
|---|---|
| WaveletMixing | Configured instance. |
Source code in spectrans/layers/mixing/wavelet.py
WaveletMixing2D
WaveletMixing2D(channels: int, wavelet: WaveletType = 'db4', levels: int = 2, mixing_mode: str = 'subband')
Bases: Module
2D wavelet mixing layer for image-like data.
Performs mixing in 2D wavelet domain, suitable for vision transformers and other architectures processing 2D spatial data. Processes spatial information through multi-resolution wavelet subbands.
Mathematical Formulation
Given input tensor \(\mathbf{X} \in \mathbb{R}^{B \times C \times H \times W}\) where \(B\) is batch size, \(C\) is channels, \(H\) is height, and \(W\) is width:
Step 1: Channel-wise 2D Decomposition
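For each channel \(c \in \{0, 1, \ldots, C-1\}\), extract spatial data:
\[
\mathbf{X}^{(c)} = \mathbf{X}[:, c, :, :] \in \mathbb{R}^{B \times H \times W}
\]
Apply \(J\)-level 2D DWT decomposition:
\[
\text{DWT2D}_J(\mathbf{X}^{(c)}) = \left\{\mathbf{LL}_J^{(c)}, \{(\mathbf{LH}_j^{(c)}, \mathbf{HL}_j^{(c)}, \mathbf{HH}_j^{(c)})\}_{j=1}^{J}\right\}
\]
Where:
- \(\mathbf{LL}_J^{(c)} \in \mathbb{R}^{B \times H_J \times W_J}\) is the approximation subband (low-low)
- \(\mathbf{LH}_j^{(c)}, \mathbf{HL}_j^{(c)}, \mathbf{HH}_j^{(c)} \in \mathbb{R}^{B \times H_j \times W_j}\) are detail subbands
- \(H_j = \frac{H}{2^j}\), \(W_j = \frac{W}{2^j}\) are spatial dimensions at level \(j\)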
Step 2: Subband Mixing
Apply mixing transformations based on mode:
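Subband Mixing (mixing_mode='subband'):
Independent processing of each subband using convolutional networks:
\[
\widetilde{\mathbf{LL}}_J^{(c)} = f_{LL}(\mathbf{LL}_J^{(c)}), \qquad \widetilde{\mathbf{S}}_j^{(c)} = f_{S}(\mathbf{S}_j^{(c)}) \quad \text{for } \mathbf{S} \in \{\mathbf{LH}, \mathbf{HL}, \mathbf{HH}\}
\]
Where \(f_{\cdot}\) are learnable convolutional transformations.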
Cross Mixing (mixing_mode='cross'):
Cross-attention is applied across all subbands.
Step 3: 2D Reconstruction
Reconstruct the spatial signal:
\[
\tilde{\mathbf{X}}^{(c)} = \text{IDWT2D}_J(\text{mixed subbands})
\]
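Step 4: Channel Concatenation and Residual
\[
\hat{\mathbf{X}} = [\tilde{\mathbf{X}}^{(0)}, \ldots, \tilde{\mathbf{X}}^{(C-1)}], \qquad \mathbf{Y} = \mathbf{X} + \hat{\mathbf{X}}
\]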
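An end-to-end NumPy sketch of Steps 1-4 for the subband mode. It assumes the orthonormal Haar wavelet ('db1') instead of the default 'db4', spatial sizes divisible by \(2^J\), and plain callables f_ll and f_detail (hypothetical names) in place of the layer's learnable convolutional mixers. Libraries differ in which of the two directional difference bands they label LH and which HL.

```python
import numpy as np


def haar_dwt2(x):
    """One-level orthonormal 2D Haar DWT over the last two axes (even H and W)."""
    a, b = x[..., 0::2, 0::2], x[..., 0::2, 1::2]  # top-left, top-right of each 2x2 block
    c, d = x[..., 1::2, 0::2], x[..., 1::2, 1::2]  # bottom-left, bottom-right
    ll = (a + b + c + d) / 2
    details = ((a + b - c - d) / 2, (a - b + c - d) / 2, (a - b - c + d) / 2)
    return ll, details


def haar_idwt2(ll, details):
    """Exact inverse of haar_dwt2."""
    lh, hl, hh = details
    out = np.empty(ll.shape[:-2] + (2 * ll.shape[-2], 2 * ll.shape[-1]))
    out[..., 0::2, 0::2] = (ll + lh + hl + hh) / 2
    out[..., 0::2, 1::2] = (ll + lh - hl - hh) / 2
    out[..., 1::2, 0::2] = (ll - lh + hl - hh) / 2
    out[..., 1::2, 1::2] = (ll - lh - hl + hh) / 2
    return out


def wavelet_mix_2d(X, levels, f_ll, f_detail):
    """Steps 1-4 for X of shape (B, C, H, W); f_ll and f_detail stand in for learnable mixers."""
    mixed_channels = []
    for c in range(X.shape[1]):
        ll, subbands = X[:, c], []                  # Step 1: X^(c), shape (B, H, W)
        for _ in range(levels):                     # J-level decomposition, recursing on LL
            ll, details = haar_dwt2(ll)
            subbands.append(details)
        ll = f_ll(ll)                               # Step 2: mix LL_J ...
        subbands = [tuple(f_detail(s) for s in details) for details in subbands]  # ... and details
        for details in reversed(subbands):          # Step 3: inverse 2D DWT
            ll = haar_idwt2(ll, details)
        mixed_channels.append(ll)
    X_hat = np.stack(mixed_channels, axis=1)        # Step 4: stack channels ...
    return X + X_hat                                # ... and add the residual


X = np.random.default_rng(0).standard_normal((2, 3, 16, 16))
Y = wavelet_mix_2d(X, levels=2, f_ll=lambda s: s, f_detail=lambda s: s)
print(Y.shape, np.allclose(Y, 2 * X))  # (2, 3, 16, 16) True
```

With identity mixers the reconstruction is exact, so the output is \(2\mathbf{X}\). Passing f_detail=np.zeros_like instead keeps only \(\mathbf{LL}_J\), and each channel's branch becomes a \(2^J \times 2^J\) block average.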
Complexity Analysis
- Time Complexity: \(O(CHW \cdot J) + O(\text{mixing operations})\)
- Space Complexity: \(O(CHW + \text{subband storage})\)
Where mixing complexity depends on mode:
- Subband: \(O(\text{conv operations per subband})\)
- Cross: \(O(\text{attention across subbands})\)
- Attention: \(O(\text{transformer encoder})\)
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| channels | int | Number of input/output channels \(C\). | required |
| wavelet | str | Wavelet type determining 2D filter bank characteristics. | 'db4' |
| levels | int | Number of decomposition levels \(J\). | 2 |
| mixing_mode | str | Subband mixing strategy: 'subband' (independent), 'cross' (attention), 'attention' (transformer). | 'subband' |
Attributes:
| Name | Type | Description |
|---|---|---|
| dwt | DWT2D | 2D wavelet transform module. |
| ll_mixer | Sequential | Convolutional network for LL subband (subband mode). |
| detail_mixers | ModuleList | Convolutional networks for detail subbands (subband mode). |
| cross_mixer | MultiheadAttention | Cross-attention module (cross mode). |
| subband_attention | TransformerEncoder | Transformer encoder for subband attention (attention mode). |
Raises:
| Type | Description |
|---|---|
| ValueError | If :attr: |
Examples:
Independent subband processing:
>>> import torch
>>> from spectrans.layers.mixing import WaveletMixing2D
>>> mixer = WaveletMixing2D(channels=256, wavelet='db4', levels=2)
>>> x = torch.randn(32, 256, 64, 64) # (batch, channels, height, width)
>>> output = mixer(x)
>>> assert output.shape == x.shape
Cross-subband attention:
>>> mixer = WaveletMixing2D(channels=128, mixing_mode='cross', levels=3)
>>> x = torch.randn(16, 128, 128, 128)
>>> output = mixer(x) # Attention applied across all wavelet subbands
Methods:
| Name | Description |
|---|---|
| forward | Apply 2D wavelet-based mixing following the mathematical formulation. |
| from_config | Create WaveletMixing2D from configuration. |
Source code in spectrans/layers/mixing/wavelet.py
Functions
forward
Apply 2D wavelet-based mixing following the mathematical formulation.
Implements complete 2D wavelet mixing: spatial decomposition → subband mixing → reconstruction → residual connection. Each channel is processed independently.
Mathematical Implementation
- Channel Extraction: \(\mathbf{X}^{(c)} = \mathbf{X}[:, c, :, :]\) for each channel \(c\)
- 2D Wavelet Decomposition: \(\text{DWT2D}_J(\mathbf{X}^{(c)}) \rightarrow \text{subbands}\)
- Subband Mixing: Apply mode-specific transformations to wavelet subbands
- 2D Reconstruction: \(\text{IDWT2D}_J(\text{mixed subbands}) \rightarrow \tilde{\mathbf{X}}^{(c)}\)
- Channel Stacking: \(\hat{\mathbf{X}} = [\tilde{\mathbf{X}}^{(0)}, \ldots, \tilde{\mathbf{X}}^{(C-1)}]\)
- Residual Connection: \(\mathbf{Y} = \mathbf{X} + \hat{\mathbf{X}}\)
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Input tensor of shape \((B, C, H, W)\) where \(B\) is batch size, \(C\) is channels, \(H\) is height, and \(W\) is width. | required |
Returns:
| Type | Description |
|---|---|
| Tensor | Mixed output tensor of identical shape \((B, C, H, W)\) with 2D wavelet-domain mixing applied and residual connection. |
Notes
- Spatial dimensions preserved through careful reconstruction handling
- Different mixing strategies provide various inductive biases:
  - Subband mode: independent processing emphasizes local features
  - Cross mode: attention enables global subband interactions
  - Attention mode: full transformer encoder for complex dependencies
Source code in spectrans/layers/mixing/wavelet.py
from_config classmethod
from_config(config: WaveletMixing2DConfig) -> WaveletMixing2D
Create WaveletMixing2D from configuration.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | WaveletMixing2DConfig | Typed and validated configuration. | required |
Returns:
| Type | Description |
|---|---|
| WaveletMixing2D | Configured instance. |