AFNO Mixing
spectrans.layers.mixing.afno
Adaptive Fourier Neural Operator (AFNO) mixing layer implementation.
This module provides the AFNO mixing layer, which performs token mixing in the Fourier domain with adaptive mode truncation and learnable spectral filters. AFNO processes sequence data efficiently by operating only on a truncated set of Fourier modes. This reduces the cost of the learnable spectral transformation while keeping its expressive power.
The AFNO architecture exploits the fact that many signals are approximately sparse in the frequency domain. It applies learnable transformations to the most significant (low-frequency) Fourier modes and discards the higher-frequency components, which often contain noise.
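The sparsity argument can be illustrated with a small NumPy sketch, independent of spectrans and PyTorch. A smooth 1D signal with additive noise keeps nearly all of its energy in a handful of low-frequency modes, and discarding the remaining modes removes most of the noise. The signal, noise level, and cutoff of 8 are arbitrary choices made for the illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 256
t = np.arange(n) / n

# Smooth signal: two low-frequency tones (3 and 5 cycles) plus small white noise.
signal = np.sin(2 * np.pi * 3 * t) + 0.5 * np.cos(2 * np.pi * 5 * t)
noisy = signal + 0.1 * rng.standard_normal(n)

spectrum = np.fft.fft(noisy)
freqs = np.fft.fftfreq(n, d=1.0 / n)  # integer frequencies 0..127, -128..-1
low = np.abs(freqs) <= 8              # keep the 17 lowest-|frequency| modes

energy = np.abs(spectrum) ** 2
kept_fraction = energy[low].sum() / energy.sum()
print(f"{low.sum()} of {n} modes hold {kept_fraction:.1%} of the energy")

# Reconstruct from the kept modes only: zero the rest and invert the FFT.
filtered = np.fft.ifft(np.where(low, spectrum, 0)).real

rms_noise_before = np.sqrt(np.mean((noisy - signal) ** 2))
rms_noise_after = np.sqrt(np.mean((filtered - signal) ** 2))
print(f"RMS noise: {rms_noise_before:.4f} -> {rms_noise_after:.4f}")
```

Because both tones lie inside the kept band, the truncated reconstruction differs from the clean signal only by the small share of noise that falls in the same 17 bins.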
Classes:
| Name | Description |
|---|---|
| AFNOMixing | Adaptive Fourier Neural Operator mixing layer with mode truncation. |
Examples:
Basic AFNO mixing layer:
>>> import torch
>>> from spectrans.layers.mixing.afno import AFNOMixing
>>> layer = AFNOMixing(hidden_dim=768, max_sequence_length=512)
>>> x = torch.randn(32, 512, 768)
>>> output = layer(x)
>>> assert output.shape == x.shape
With custom mode truncation:
>>> # Keep 25% of the sequence modes (128/512) and 50% of the hidden modes (384/768)
>>> layer = AFNOMixing(
... hidden_dim=768,
... max_sequence_length=512,
... modes_seq=128, # Keep 128 modes in sequence dimension
... modes_hidden=384 # Keep 384 modes in hidden dimension
... )
Notes
The AFNO mixing operation proceeds as follows:
- Apply a 2D FFT to the input tensor, treating the sequence and hidden dimensions as spatial dimensions
- Truncate the spectrum to keep only the low-frequency modes
- Apply a learnable MLP to the truncated modes
- Zero-pad back to the original size and apply the inverse FFT
- Add a residual connection
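The five steps can be sketched in NumPy as below. This is an illustration under stated assumptions, not the spectrans implementation. The learnable MLP is modeled as two hypothetical complex weight matrices `w1` and `w2` acting along the hidden-mode axis. The activation (tanh-approximated GELU) is applied to real and imaginary parts separately. Low-frequency modes are selected as a centered window after `fftshift`, and only the real part of the inverse FFT is kept.

```python
import numpy as np

def gelu(z):
    # Tanh approximation of GELU, applied elementwise to a real array.
    return 0.5 * z * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (z + 0.044715 * z**3)))

def complex_gelu(z):
    # Assumption: the activation acts on real and imaginary parts separately.
    return gelu(z.real) + 1j * gelu(z.imag)

def centered_window(n, m):
    # Slice selecting the m lowest-|frequency| bins of a length-n fftshift-ed axis.
    start = n // 2 - m // 2
    return slice(start, start + m)

def afno_mix_sketch(x, modes_seq, modes_hidden, w1, w2):
    """x: real array of shape (batch, seq_len, hidden_dim)."""
    _, n, d = x.shape
    # 1. 2D FFT over sequence and hidden axes; move zero frequency to the center.
    spec = np.fft.fftshift(np.fft.fft2(x, axes=(1, 2)), axes=(1, 2))
    # 2. Keep only the low-frequency block of shape (modes_seq, modes_hidden).
    rows, cols = centered_window(n, modes_seq), centered_window(d, modes_hidden)
    low = spec[:, rows, cols]
    # 3. Two-layer complex MLP along the hidden-mode axis (hypothetical weights).
    low = complex_gelu(low @ w1) @ w2
    # 4. Zero-pad back to (seq_len, hidden_dim) and invert the FFT.
    padded = np.zeros_like(spec)
    padded[:, rows, cols] = low
    mixed = np.fft.ifft2(np.fft.ifftshift(padded, axes=(1, 2)), axes=(1, 2))
    # The MLP breaks Hermitian symmetry, so keep only the real part (a simplification).
    # 5. Residual connection.
    return x + mixed.real

rng = np.random.default_rng(0)
m_s, m_h, mlp_ratio = 4, 8, 2.0
width = int(m_h * mlp_ratio)
w1 = 0.02 * (rng.standard_normal((m_h, width)) + 1j * rng.standard_normal((m_h, width)))
w2 = 0.02 * (rng.standard_normal((width, m_h)) + 1j * rng.standard_normal((width, m_h)))

x = rng.standard_normal((2, 16, 32))
y = afno_mix_sketch(x, m_s, m_h, w1, w2)
print(y.shape)
```

With zero weights the spectral branch vanishes and the sketch reduces to its residual path, which is a convenient sanity check.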
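Because the learnable transformation acts only on the retained modes, mode truncation reduces the memory and computation spent in the Fourier domain. Together with the O(N log N) scaling of the FFT in the sequence length N, this makes AFNO efficient for long sequences.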
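To put numbers on the savings: the full 2D spectrum of a (512, 768) input has 512 × 768 modes per sample. The default truncation keeps 256 × 384 of them, and the custom example above keeps 128 × 384. The helper `spectrum_sizes` below is hypothetical, not part of spectrans. Its byte count assumes one complex64 value (8 bytes) per retained mode.

```python
def spectrum_sizes(seq_len, hidden_dim, modes_seq, modes_hidden, bytes_per_mode=8):
    # bytes_per_mode=8 assumes complex64 (two float32 values) per Fourier mode.
    full = seq_len * hidden_dim
    kept = modes_seq * modes_hidden
    return full, kept, kept / full, kept * bytes_per_mode

for modes in [(256, 384), (128, 384)]:
    full, kept, frac, nbytes = spectrum_sizes(512, 768, *modes)
    print(f"modes={modes}: {kept} of {full} modes ({frac:.1%}), "
          f"{nbytes / 2**20:.3f} MiB per sample")
```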
References
John Guibas, Morteza Mardani, Zongyi Li, Andrew Tao, Anima Anandkumar, and Bryan Catanzaro. 2022. Adaptive Fourier neural operators: Efficient token mixers for transformers. In Proceedings of the International Conference on Learning Representations (ICLR).
See Also
spectrans.layers.mixing.fourier : Standard Fourier mixing without mode truncation.
spectrans.layers.operators.fno : Fourier Neural Operator implementation.
Classes
AFNOMixing
AFNOMixing(hidden_dim: int, max_sequence_length: int, modes_seq: int | None = None, modes_hidden: int | None = None, mlp_ratio: float = 2.0, activation: ActivationType = 'gelu', dropout: float = 0.0)
Bases: MixingLayer
Adaptive Fourier Neural Operator mixing layer.
This layer performs token mixing by applying learnable transformations to a truncated set of Fourier modes. This lowers the cost of the spectral transformation while keeping global mixing across the whole sequence.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| hidden_dim | int | Hidden dimension of the input/output tensors. | required |
| max_sequence_length | int | Maximum sequence length the model will process. | required |
| modes_seq | int \| None | Number of Fourier modes to keep in the sequence dimension. If None, defaults to max_sequence_length // 2. | None |
| modes_hidden | int \| None | Number of Fourier modes to keep in the hidden dimension. If None, defaults to hidden_dim // 2. | None |
| mlp_ratio | float | Expansion ratio of the MLP applied in the Fourier domain. | 2.0 |
| activation | ActivationType | Activation function of the MLP. | 'gelu' |
| dropout | float | Dropout probability of the MLP. | 0.0 |
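The `None` defaults for `modes_seq` and `modes_hidden` resolve as stated in the parameter descriptions. A minimal sketch of that resolution follows, using a hypothetical helper `resolve_modes` that is not part of spectrans.

```python
def resolve_modes(hidden_dim, max_sequence_length, modes_seq=None, modes_hidden=None):
    # Hypothetical helper mirroring the documented defaults.
    if modes_seq is None:
        modes_seq = max_sequence_length // 2
    if modes_hidden is None:
        modes_hidden = hidden_dim // 2
    return modes_seq, modes_hidden

print(resolve_modes(768, 512))                 # all defaults
print(resolve_modes(768, 512, modes_seq=128))  # custom sequence modes only
```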
Attributes:
| Name | Type | Description |
|---|---|---|
| hidden_dim | int | Hidden dimension size. |
| max_sequence_length | int | Maximum supported sequence length. |
| modes_seq | int | Number of retained Fourier modes in the sequence dimension. |
| modes_hidden | int | Number of retained Fourier modes in the hidden dimension. |
| mlp_ratio | float | MLP expansion ratio. |
| fourier_weight | Parameter | Complex-valued learnable weights for the Fourier modes. |
| mlp | Sequential | MLP applied in the Fourier domain. |
Examples:
>>> import torch
>>> from spectrans.layers.mixing.afno import AFNOMixing
>>> layer = AFNOMixing(hidden_dim=768, max_sequence_length=512, modes_seq=128)
>>> x = torch.randn(32, 512, 768)
>>> output = layer(x)
>>> print(output.shape)
torch.Size([32, 512, 768])
Methods:
| Name | Description |
|---|---|
| forward | Apply AFNO mixing to input tensor. |
| get_spectral_properties | Get mathematical properties of AFNO operation. |
| from_config | Create AFNOMixing layer from configuration. |
Source code in spectrans/layers/mixing/afno.py
Functions
forward
Apply AFNO mixing to input tensor.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Input tensor of shape (batch_size, sequence_length, hidden_dim). | required |
Returns:
| Type | Description |
|---|---|
| Tensor | Output tensor of the same shape as the input. |
get_spectral_properties
Get mathematical properties of the AFNO operation.
Returns:
| Type | Description |
|---|---|
| dict[str, bool] | Mathematical properties of the transform. |
from_config classmethod
from_config(config: AFNOMixingConfig) -> AFNOMixing
Create AFNOMixing layer from configuration.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | AFNOMixingConfig | Configuration object with layer parameters. | required |
Returns:
| Type | Description |
|---|---|
| AFNOMixing | Configured AFNO mixing layer. |
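The config-to-layer pattern behind `from_config` can be sketched with hypothetical stand-in classes. `AFNOMixingConfigSketch` and `AFNOMixingSketch` are not spectrans classes. The config's field names are assumed to match the constructor parameters listed above, which the real `AFNOMixingConfig` may or may not do.

```python
from __future__ import annotations

from dataclasses import asdict, dataclass


@dataclass
class AFNOMixingConfigSketch:
    # Hypothetical stand-in for AFNOMixingConfig; fields assumed to match the constructor.
    hidden_dim: int
    max_sequence_length: int
    modes_seq: int | None = None
    modes_hidden: int | None = None
    mlp_ratio: float = 2.0
    activation: str = "gelu"
    dropout: float = 0.0


class AFNOMixingSketch:
    # Hypothetical stand-in that only records its (resolved) arguments.
    def __init__(self, hidden_dim, max_sequence_length, modes_seq=None,
                 modes_hidden=None, mlp_ratio=2.0, activation="gelu", dropout=0.0):
        self.hidden_dim = hidden_dim
        self.max_sequence_length = max_sequence_length
        self.modes_seq = max_sequence_length // 2 if modes_seq is None else modes_seq
        self.modes_hidden = hidden_dim // 2 if modes_hidden is None else modes_hidden
        self.mlp_ratio = mlp_ratio
        self.activation = activation
        self.dropout = dropout

    @classmethod
    def from_config(cls, config: AFNOMixingConfigSketch) -> AFNOMixingSketch:
        # Unpack every config field as a keyword argument of the constructor.
        return cls(**asdict(config))


layer = AFNOMixingSketch.from_config(
    AFNOMixingConfigSketch(hidden_dim=768, max_sequence_length=512, modes_seq=128)
)
print(layer.modes_seq, layer.modes_hidden)
```

Keeping the config fields aligned one-to-one with the constructor arguments lets `from_config` stay a one-liner.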