AFNO Models
spectrans.models.afno
Adaptive Fourier Neural Operator (AFNO) transformer models.
This module implements the AFNO architecture, which performs efficient token mixing by applying learnable transformations in the truncated Fourier domain. AFNO leverages the sparsity of signals in frequency space to achieve computational efficiency while maintaining performance.
The architecture uses adaptive mode truncation: it keeps only the most significant Fourier modes and applies MLPs in the frequency domain, which reduces computational requirements for long sequences.
Classes:
| Name | Description |
|---|---|
| `AFNOModel` | Complete AFNO model with adaptive Fourier mixing layers. |
| `AFNOEncoder` | Encoder-only AFNO for representation learning. |
Examples:
Basic AFNO usage for classification:
>>> import torch
>>> from spectrans.models.afno import AFNOModel
>>> model = AFNOModel(
... vocab_size=30000,
... hidden_dim=768,
... num_layers=12,
... max_sequence_length=1024,
... num_classes=2,
... modes_seq=256, # Keep 256 modes in sequence dimension
... modes_hidden=384 # Keep 384 modes in hidden dimension
... )
>>> input_ids = torch.randint(0, 30000, (8, 1024))
>>> logits = model(input_ids=input_ids)
>>> assert logits.shape == (8, 2)
Using AFNO encoder for long sequences:
>>> from spectrans.models.afno import AFNOEncoder
>>> encoder = AFNOEncoder(
... hidden_dim=768,
... num_layers=12,
... max_sequence_length=4096,
... modes_seq=512, # Truncation for efficiency
... modes_hidden=384
... )
>>> inputs = torch.randn(8, 4096, 768)
>>> features = encoder(inputs_embeds=inputs)
>>> assert features.shape == (8, 4096, 768)
Creating from configuration:
>>> from spectrans.config.models import AFNOModelConfig
>>> config = AFNOModelConfig(
... hidden_dim=768,
... num_layers=12,
... sequence_length=1024,
... n_modes=256,
... compression_ratio=0.5
... )
>>> model = AFNOModel.from_config(config)
Notes
Mathematical Foundation
Given input tensor \(\mathbf{X} \in \mathbb{R}^{n \times d}\) where \(n\) is sequence length and \(d\) is hidden dimension, AFNO applies mode-truncated Fourier operations with learnable transformations.
Adaptive Fourier Operation:
The core AFNO operation consists of four steps:
- 2D Fourier Transform:
$$ \mathbf{X}_{\text{freq}} = \mathcal{F}_{2D}(\mathbf{X}) $$
where \(\mathbf{X}_{\text{freq}} \in \mathbb{C}^{n \times d}\) is the frequency representation.
- Mode Truncation:
$$ \mathbf{X}_{\text{trunc}} = \mathbf{X}_{\text{freq}}[0:k_n, 0:k_d] $$
where \(k_n \ll n\) and \(k_d \ll d\) are the number of retained modes, resulting in \(\mathbf{X}_{\text{trunc}} \in \mathbb{C}^{k_n \times k_d}\).
- Frequency Domain MLP:
$$ \mathbf{Y}_{\text{freq}} = \text{MLP}(\mathbf{X}_{\text{trunc}}) \odot \mathbf{X}_{\text{trunc}} $$
where \(\odot\) denotes element-wise (Hadamard) multiplication and the MLP operates on complex values with expansion ratio \(r\):
$$ \text{MLP}(\mathbf{z}) = \mathbf{W}_2 \cdot \text{GELU}(\mathbf{W}_1 \mathbf{z} + \mathbf{b}_1) + \mathbf{b}_2 $$
with \(\mathbf{W}_1 \in \mathbb{C}^{rk_d \times k_d}\), \(\mathbf{W}_2 \in \mathbb{C}^{k_d \times rk_d}\).
- Zero-padding and Inverse Transform:
$$ \mathbf{Y} = \Re\left(\mathcal{F}_{2D}^{-1}(\text{pad}(\mathbf{Y}_{\text{freq}}))\right) $$
where \(\text{pad}\) zero-pads to original dimensions \(n \times d\) and \(\Re(\cdot)\) takes the real part.
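The four steps above can be sketched in a few lines of NumPy. This is an illustrative sketch, not the spectrans implementation: the function name `afno_mix`, the random weights, and the choice to apply GELU separately to the real and imaginary parts of complex values are all assumptions made for the example.

```python
import numpy as np


def gelu(x):
    # Tanh approximation of GELU for real-valued arrays.
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))


def complex_gelu(z):
    # Assumption: the activation acts on real and imaginary parts independently.
    return gelu(z.real) + 1j * gelu(z.imag)


def afno_mix(x, k_n, k_d, w1, b1, w2, b2):
    """Hypothetical sketch of the four-step AFNO operation on an (n, d) array."""
    n, d = x.shape
    x_freq = np.fft.fft2(x)                        # 1. 2D Fourier transform
    x_trunc = x_freq[:k_n, :k_d]                   # 2. keep k_n x k_d modes
    hidden = complex_gelu(x_trunc @ w1.T + b1)     # 3. MLP, shape (k_n, r * k_d)
    y_freq = (hidden @ w2.T + b2) * x_trunc        #    Hadamard product with X_trunc
    padded = np.zeros((n, d), dtype=complex)       # 4. zero-pad back to (n, d)
    padded[:k_n, :k_d] = y_freq
    return np.fft.ifft2(padded).real               #    inverse transform, real part


rng = np.random.default_rng(0)
n, d, k_n, k_d, r = 16, 8, 4, 4, 2
x = rng.standard_normal((n, d))
w1 = rng.standard_normal((r * k_d, k_d)) + 1j * rng.standard_normal((r * k_d, k_d))
b1 = np.zeros(r * k_d, dtype=complex)
w2 = rng.standard_normal((k_d, r * k_d)) + 1j * rng.standard_normal((k_d, r * k_d))
b2 = np.zeros(k_d, dtype=complex)

y = afno_mix(x, k_n, k_d, w1, b1, w2, b2)
print(y.shape)
```

The output has the same shape as the input, so the operation can stand in for a token mixing layer without changing the surrounding dimensions.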
Complete Layer Operations:
Each AFNO layer \(l\) applies the adaptive Fourier operation described above for token mixing, together with a feed-forward network (FFN) that follows the same structure as in standard transformers.
Complexity Analysis:
- Time Complexity: \(O(L \cdot (nd \log(nd) + k_n k_d d))\) where \(L\) is the number of layers
- Space Complexity: \(O(L \cdot k_n \cdot k_d \cdot d)\)
- Memory reduction from \(O(nd)\) to \(O(k_n k_d)\) per layer through mode truncation
The mode truncation significantly reduces memory usage, with typical settings using \(k_n = \frac{n}{4}\) and \(k_d = \frac{d}{2}\) achieving 8x memory reduction while maintaining performance.
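The 8x figure can be checked with simple arithmetic. The sketch below uses \(n = 1024\) and \(d = 768\), which are illustrative values matching the defaults documented on this page; the helper name `mode_reduction` is hypothetical.

```python
def mode_reduction(n, d, k_n, k_d):
    """Return full mode count, retained mode count, and their ratio."""
    full = n * d          # entries in the full n x d spectrum
    kept = k_n * k_d      # entries kept after truncation
    return full, kept, full / kept


n, d = 1024, 768
full, kept, reduction = mode_reduction(n, d, k_n=n // 4, k_d=d // 2)
print(full, kept, reduction)  # 786432 98304 8.0
```

With \(k_n = n/4\) and \(k_d = d/2\), the retained fraction is \(\frac{1}{4} \cdot \frac{1}{2} = \frac{1}{8}\), independent of the particular \(n\) and \(d\).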
References
John Guibas, Morteza Mardani, Zongyi Li, Andrew Tao, Anima Anandkumar, and Bryan Catanzaro. 2022. Adaptive Fourier neural operators: Efficient token mixers for transformers. In Proceedings of the International Conference on Learning Representations (ICLR).
See Also
- spectrans.layers.mixing.afno : AFNO mixing layer implementation.
- spectrans.layers.operators.fno : Related Fourier Neural Operator implementation.
- spectrans.models.base : Base model classes.
Classes
AFNOModel
AFNOModel(vocab_size: int | None = None, hidden_dim: int = 768, num_layers: int = 12, max_sequence_length: int = 1024, modes_seq: int | None = None, modes_hidden: int | None = None, mlp_ratio: float = 2.0, num_classes: int | None = None, use_positional_encoding: bool = True, positional_encoding_type: PositionalEncodingType = 'sinusoidal', dropout: float = 0.1, ffn_hidden_dim: int | None = None, norm_eps: float = 1e-12, output_type: OutputHeadType = 'classification', gradient_checkpointing: bool = False)
Bases: BaseModel
Adaptive Fourier Neural Operator transformer model.
AFNO performs token mixing using truncated Fourier modes and learnable MLPs in the frequency domain, processing long sequences with \(O(n \log n)\) time complexity.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `vocab_size` | `int \| None` | Size of the vocabulary for token embeddings. If None, expects pre-embedded inputs. | `None` |
| `hidden_dim` | `int` | Hidden dimension size. Default is 768. | `768` |
| `num_layers` | `int` | Number of AFNO layers. Default is 12. | `12` |
| `max_sequence_length` | `int` | Maximum sequence length. Default is 1024. | `1024` |
| `modes_seq` | `int \| None` | Number of Fourier modes to keep in sequence dimension. If None, defaults to max_sequence_length // 2. | `None` |
| `modes_hidden` | `int \| None` | Number of Fourier modes to keep in hidden dimension. If None, defaults to hidden_dim // 2. | `None` |
| `mlp_ratio` | `float` | Expansion ratio for MLP in Fourier domain. Default is 2.0. | `2.0` |
| `num_classes` | `int \| None` | Number of output classes for classification. Default is None. | `None` |
| `use_positional_encoding` | `bool` | Whether to use positional encoding. Default is True. | `True` |
| `positional_encoding_type` | `str` | Type of positional encoding: 'sinusoidal' or 'learned'. Default is 'sinusoidal'. | `'sinusoidal'` |
| `dropout` | `float` | Dropout probability. Default is 0.1. | `0.1` |
| `ffn_hidden_dim` | `int \| None` | Hidden dimension for FFN. If None, defaults to 4 * hidden_dim. | `None` |
| `norm_eps` | `float` | Epsilon for layer normalization. Default is 1e-12. | `1e-12` |
| `output_type` | `str` | Type of output head: 'classification', 'regression', 'sequence', or 'none'. Default is 'classification'. | `'classification'` |
| `gradient_checkpointing` | `bool` | Whether to use gradient checkpointing. Default is False. | `False` |
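The `None` fallbacks documented for `modes_seq`, `modes_hidden`, and `ffn_hidden_dim` amount to the following arithmetic. The helper `resolve_afno_defaults` is hypothetical and written only to illustrate the documented rules; it is not part of spectrans.

```python
def resolve_afno_defaults(
    hidden_dim=768,
    max_sequence_length=1024,
    modes_seq=None,
    modes_hidden=None,
    ffn_hidden_dim=None,
):
    """Hypothetical helper mirroring the documented fallback rules."""
    if modes_seq is None:
        modes_seq = max_sequence_length // 2   # half of the sequence modes
    if modes_hidden is None:
        modes_hidden = hidden_dim // 2         # half of the hidden modes
    if ffn_hidden_dim is None:
        ffn_hidden_dim = 4 * hidden_dim        # standard 4x FFN expansion
    return {
        "modes_seq": modes_seq,
        "modes_hidden": modes_hidden,
        "ffn_hidden_dim": ffn_hidden_dim,
    }


resolved = resolve_afno_defaults()
print(resolved)  # {'modes_seq': 512, 'modes_hidden': 384, 'ffn_hidden_dim': 3072}
```

Explicitly passed values are left untouched, so `resolve_afno_defaults(modes_seq=256)` keeps 256 sequence modes while still deriving the other two values.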
Attributes:
| Name | Type | Description |
|---|---|---|
| `modes_seq` | `int` | Number of Fourier modes in sequence dimension. |
| `modes_hidden` | `int` | Number of Fourier modes in hidden dimension. |
| `mlp_ratio` | `float` | MLP expansion ratio in frequency domain. |
| `blocks` | `ModuleList` | List of AFNO transformer blocks. |
Methods:
| Name | Description |
|---|---|
| `build_blocks` | Build AFNO transformer blocks with adaptive Fourier mixing. |
| `from_config` | Create AFNO model from configuration. |
Source code in spectrans/models/afno.py
Functions
build_blocks
Build AFNO transformer blocks with adaptive Fourier mixing.
Returns:
| Type | Description |
|---|---|
| `ModuleList` | List of AFNO transformer blocks. |
from_config classmethod
from_config(config: AFNOModelConfig) -> AFNOModel
Create AFNO model from configuration.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `config` | `AFNOModelConfig` | Configuration object with model parameters. | required |
Returns:
| Type | Description |
|---|---|
| `AFNOModel` | Configured AFNO model. |
AFNOEncoder
AFNOEncoder(hidden_dim: int = 768, num_layers: int = 12, max_sequence_length: int = 1024, modes_seq: int | None = None, modes_hidden: int | None = None, mlp_ratio: float = 2.0, use_positional_encoding: bool = True, positional_encoding_type: PositionalEncodingType = 'sinusoidal', dropout: float = 0.1, ffn_hidden_dim: int | None = None, norm_eps: float = 1e-12, gradient_checkpointing: bool = False)
Bases: AFNOModel
Encoder-only AFNO model for representation learning.
This variant of AFNO is designed for tasks that require extracting representations rather than making predictions. It's particularly efficient for processing very long sequences due to the mode truncation.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `hidden_dim` | `int` | Hidden dimension size. Default is 768. | `768` |
| `num_layers` | `int` | Number of AFNO layers. Default is 12. | `12` |
| `max_sequence_length` | `int` | Maximum sequence length. Default is 1024. | `1024` |
| `modes_seq` | `int \| None` | Number of Fourier modes in sequence dimension. | `None` |
| `modes_hidden` | `int \| None` | Number of Fourier modes in hidden dimension. | `None` |
| `mlp_ratio` | `float` | MLP expansion ratio. Default is 2.0. | `2.0` |
| `use_positional_encoding` | `bool` | Whether to use positional encoding. Default is True. | `True` |
| `positional_encoding_type` | `str` | Type of positional encoding. Default is 'sinusoidal'. | `'sinusoidal'` |
| `dropout` | `float` | Dropout probability. Default is 0.1. | `0.1` |
| `ffn_hidden_dim` | `int \| None` | Hidden dimension for FFN. If None, defaults to 4 * hidden_dim. | `None` |
| `norm_eps` | `float` | Epsilon for layer normalization. Default is 1e-12. | `1e-12` |
| `gradient_checkpointing` | `bool` | Whether to use gradient checkpointing. Default is False. | `False` |