Skip to content

AFNO Models

spectrans.models.afno

Adaptive Fourier Neural Operator (AFNO) transformer models.

This module implements the AFNO architecture, which performs efficient token mixing by applying learnable transformations in the truncated Fourier domain. AFNO leverages the sparsity of signals in frequency space to achieve computational efficiency while maintaining performance.

The architecture uses adaptive mode truncation - keeping only the most significant Fourier modes and applying MLPs in the frequency domain, reducing computational requirements for long sequences.

Classes:

Name Description
AFNOModel

Complete AFNO model with adaptive Fourier mixing layers.

AFNOEncoder

Encoder-only AFNO for representation learning.

Examples:

Basic AFNO usage for classification:

>>> from spectrans.models.afno import AFNOModel
>>> model = AFNOModel(
...     vocab_size=30000,
...     hidden_dim=768,
...     num_layers=12,
...     max_sequence_length=1024,
...     num_classes=2,
...     modes_seq=256,  # Keep 256 modes in sequence dimension
...     modes_hidden=384  # Keep 384 modes in hidden dimension
... )
>>> input_ids = torch.randint(0, 30000, (8, 1024))
>>> logits = model(input_ids=input_ids)
>>> assert logits.shape == (8, 2)

Using AFNO encoder for long sequences:

>>> from spectrans.models.afno import AFNOEncoder
>>> encoder = AFNOEncoder(
...     hidden_dim=768,
...     num_layers=12,
...     max_sequence_length=4096,
...     modes_seq=512,  # Truncation for efficiency
...     modes_hidden=384
... )
>>> inputs = torch.randn(8, 4096, 768)
>>> features = encoder(inputs_embeds=inputs)
>>> assert features.shape == (8, 4096, 768)

Creating from configuration:

>>> from spectrans.config.models import AFNOModelConfig
>>> config = AFNOModelConfig(
...     hidden_dim=768,
...     num_layers=12,
...     sequence_length=1024,
...     n_modes=256,
...     compression_ratio=0.5
... )
>>> model = AFNOModel.from_config(config)
Notes

Mathematical Foundation ~~~~~~~~~~~~~~~~~~~~~~~

Given input tensor \(\mathbf{X} \in \mathbb{R}^{n \times d}\) where \(n\) is sequence length and \(d\) is hidden dimension, AFNO applies mode-truncated Fourier operations with learnable transformations.

Adaptive Fourier Operation:

The core AFNO operation consists of four steps:

  1. 2D Fourier Transform:

$$ \mathbf{X}{\text{freq}} = \mathcal{F}) $$}(\mathbf{X

where \(\mathbf{X}_{\text{freq}} \in \mathbb{C}^{n \times d}\) is the frequency representation.

  1. Mode Truncation:

$$ \mathbf{X}{\text{trunc}} = \mathbf{X}[0:k_n, 0:k_d] $$}

where \(k_n \ll n\) and \(k_d \ll d\) are the number of retained modes, resulting in \(\mathbf{X}_{\text{trunc}} \in \mathbb{C}^{k_n \times k_d}\).

  1. Frequency Domain MLP:

$$ \mathbf{Y}{\text{freq}} = \text{MLP}(\mathbf{X} $$}}) \odot \mathbf{X}_{\text{trunc}

where \(\odot\) denotes element-wise (Hadamard) multiplication and the MLP operates on complex values with expansion ratio \(r\):

$$ \text{MLP}(\mathbf{z}) = \mathbf{W}_2 \cdot \text{GELU}(\mathbf{W}_1 \mathbf{z} + \mathbf{b}_1) + \mathbf{b}_2 $$

with \(\mathbf{W}_1 \in \mathbb{C}^{rk_d \times k_d}\), \(\mathbf{W}_2 \in \mathbb{C}^{k_d \times rk_d}\).

  1. Zero-padding and Inverse Transform:

$$ \mathbf{Y} = \Re\left(\mathcal{F}{2D}^{-1}(\text{pad}(\mathbf{Y}))\right) $$}

where \(\text{pad}\) zero-pads to original dimensions \(n \times d\) and \(\Re(\cdot)\) takes the real part.

Complete Layer Operations:

For each AFNO layer \(l\), the computation proceeds as:

\[ \mathbf{Z}_l = \mathbf{X}_l + \text{AFNO}(\text{LayerNorm}(\mathbf{X}_l)) \]
\[ \mathbf{X}_{l+1} = \mathbf{Z}_l + \text{FFN}(\text{LayerNorm}(\mathbf{Z}_l)) \]

where FFN follows the same structure as in standard transformers.

Complexity Analysis:

  • Time Complexity: \(O(L \cdot (nd \log(nd) + k_n k_d d))\) where \(L\) is the number of layers
  • Space Complexity: \(O(L \cdot k_n \cdot k_d \cdot d)\)
  • Memory reduction from \(O(nd)\) to \(O(k_n k_d)\) per layer through mode truncation

The mode truncation significantly reduces memory usage, with typical settings using \(k_n = \frac{n}{4}\) and \(k_d = \frac{d}{2}\) achieving 8x memory reduction while maintaining performance.

References

John Guibas, Morteza Mardani, Zongyi Li, Andrew Tao, Anima Anandkumar, and Bryan Catanzaro. 2022. Adaptive Fourier neural operators: Efficient token mixers for transformers. In Proceedings of the International Conference on Learning Representations (ICLR).

See Also

spectrans.layers.mixing.afno : AFNO mixing layer implementation. spectrans.layers.operators.fno : Related Fourier Neural Operator implementation. spectrans.models.base : Base model classes.

Classes

AFNOModel

AFNOModel(vocab_size: int | None = None, hidden_dim: int = 768, num_layers: int = 12, max_sequence_length: int = 1024, modes_seq: int | None = None, modes_hidden: int | None = None, mlp_ratio: float = 2.0, num_classes: int | None = None, use_positional_encoding: bool = True, positional_encoding_type: PositionalEncodingType = 'sinusoidal', dropout: float = 0.1, ffn_hidden_dim: int | None = None, norm_eps: float = 1e-12, output_type: OutputHeadType = 'classification', gradient_checkpointing: bool = False)

Bases: BaseModel

Adaptive Fourier Neural Operator transformer model.

AFNO performs token mixing using truncated Fourier modes and learnable MLPs in the frequency domain, processing long sequences with \(O(n \log n)\) time complexity.

Parameters:

Name Type Description Default
vocab_size int | None

Size of the vocabulary for token embeddings. If None, expects pre-embedded inputs.

None
hidden_dim int

Hidden dimension size. Default is 768.

768
num_layers int

Number of AFNO layers. Default is 12.

12
max_sequence_length int

Maximum sequence length. Default is 1024.

1024
modes_seq int | None

Number of Fourier modes to keep in sequence dimension. If None, defaults to max_sequence_length // 2.

None
modes_hidden int | None

Number of Fourier modes to keep in hidden dimension. If None, defaults to hidden_dim // 2.

None
mlp_ratio float

Expansion ratio for MLP in Fourier domain. Default is 2.0.

2.0
num_classes int | None

Number of output classes for classification. Default is None.

None
use_positional_encoding bool

Whether to use positional encoding. Default is True.

True
positional_encoding_type str

Type of positional encoding: 'sinusoidal' or 'learned'. Default is 'sinusoidal'.

'sinusoidal'
dropout float

Dropout probability. Default is 0.1.

0.1
ffn_hidden_dim int | None

Hidden dimension for FFN. If None, defaults to 4 * hidden_dim.

None
norm_eps float

Epsilon for layer normalization. Default is 1e-12.

1e-12
output_type str

Type of output head: 'classification', 'regression', 'sequence', or 'none'. Default is 'classification'.

'classification'
gradient_checkpointing bool

Whether to use gradient checkpointing. Default is False.

False

Attributes:

Name Type Description
modes_seq int

Number of Fourier modes in sequence dimension.

modes_hidden int

Number of Fourier modes in hidden dimension.

mlp_ratio float

MLP expansion ratio in frequency domain.

blocks ModuleList

List of AFNO transformer blocks.

Methods:

Name Description
build_blocks

Build AFNO transformer blocks with adaptive Fourier mixing.

from_config

Create AFNO model from configuration.

Source code in spectrans/models/afno.py
def __init__(
    self,
    vocab_size: int | None = None,
    hidden_dim: int = 768,
    num_layers: int = 12,
    max_sequence_length: int = 1024,
    modes_seq: int | None = None,
    modes_hidden: int | None = None,
    mlp_ratio: float = 2.0,
    num_classes: int | None = None,
    use_positional_encoding: bool = True,
    positional_encoding_type: PositionalEncodingType = "sinusoidal",
    dropout: float = 0.1,
    ffn_hidden_dim: int | None = None,
    norm_eps: float = 1e-12,
    output_type: OutputHeadType = "classification",
    gradient_checkpointing: bool = False,
):
    self.modes_seq = modes_seq or (max_sequence_length // 2)
    self.modes_hidden = modes_hidden or (hidden_dim // 2)
    self.mlp_ratio = mlp_ratio
    self._dropout_rate = dropout  # Store dropout rate for build_blocks

    super().__init__(
        vocab_size=vocab_size,
        hidden_dim=hidden_dim,
        num_layers=num_layers,
        max_sequence_length=max_sequence_length,
        num_classes=num_classes,
        use_positional_encoding=use_positional_encoding,
        positional_encoding_type=positional_encoding_type,
        dropout=dropout,
        ffn_hidden_dim=ffn_hidden_dim,
        norm_eps=norm_eps,
        output_type=output_type,
        gradient_checkpointing=gradient_checkpointing,
    )
Functions
build_blocks
build_blocks() -> ModuleList

Build AFNO transformer blocks with adaptive Fourier mixing.

Returns:

Type Description
ModuleList

List of AFNO transformer blocks.

Source code in spectrans/models/afno.py
def build_blocks(self) -> nn.ModuleList:
    """Build AFNO transformer blocks with adaptive Fourier mixing.

    Returns
    -------
    nn.ModuleList
        List of AFNO transformer blocks.
    """
    blocks = []
    for _ in range(self.num_layers):
        # Create AFNO mixing layer
        mixing_layer = AFNOMixing(
            hidden_dim=self.hidden_dim,
            max_sequence_length=self.max_sequence_length,
            modes_seq=self.modes_seq,
            modes_hidden=self.modes_hidden,
            mlp_ratio=self.mlp_ratio,
            dropout=self._dropout_rate,
        )

        # Create AFNO block with pre-normalization
        block = PreNormBlock(
            mixing_layer=mixing_layer,
            hidden_dim=self.hidden_dim,
            ffn_hidden_dim=self.ffn_hidden_dim,
            activation="gelu",
            dropout=self._dropout_rate,
            norm_eps=self.norm_eps,
        )
        blocks.append(block)

    return nn.ModuleList(blocks)
from_config classmethod
from_config(config: AFNOModelConfig) -> AFNOModel

Create AFNO model from configuration.

Parameters:

Name Type Description Default
config AFNOModelConfig

Configuration object with model parameters.

required

Returns:

Type Description
AFNOModel

Configured AFNO model.

Source code in spectrans/models/afno.py
@classmethod
def from_config(cls, config: "AFNOModelConfig") -> "AFNOModel":  # type: ignore[override]
    """Create AFNO model from configuration.

    Parameters
    ----------
    config : AFNOModelConfig
        Configuration object with model parameters.

    Returns
    -------
    AFNOModel
        Configured AFNO model.
    """
    # Handle AFNO-specific mode configuration
    n_modes = getattr(config, "n_modes", None)
    modes_seq = getattr(config, "modes_seq", None)
    modes_hidden = getattr(config, "modes_hidden", None)

    # If n_modes is provided but not modes_seq/modes_hidden, compute them
    if n_modes is not None and modes_seq is None:
        modes_seq = n_modes
    if n_modes is not None and modes_hidden is None:
        compression_ratio = getattr(config, "compression_ratio", 0.5)
        modes_hidden = int(n_modes * compression_ratio)

    return cls(
        vocab_size=getattr(config, "vocab_size", None),
        hidden_dim=config.hidden_dim,
        num_layers=config.num_layers,
        max_sequence_length=config.sequence_length,
        modes_seq=modes_seq,
        modes_hidden=modes_hidden,
        mlp_ratio=getattr(config, "mlp_ratio", 2.0),
        num_classes=getattr(config, "num_classes", None),
        use_positional_encoding=getattr(config, "use_positional_encoding", True),
        positional_encoding_type=getattr(config, "positional_encoding_type", "sinusoidal"),
        dropout=config.dropout,
        ffn_hidden_dim=getattr(config, "ffn_hidden_dim", None),
        norm_eps=getattr(config, "norm_eps", 1e-12),
        output_type=getattr(config, "output_type", "classification"),
        gradient_checkpointing=getattr(config, "gradient_checkpointing", False),
    )

AFNOEncoder

AFNOEncoder(hidden_dim: int = 768, num_layers: int = 12, max_sequence_length: int = 1024, modes_seq: int | None = None, modes_hidden: int | None = None, mlp_ratio: float = 2.0, use_positional_encoding: bool = True, positional_encoding_type: PositionalEncodingType = 'sinusoidal', dropout: float = 0.1, ffn_hidden_dim: int | None = None, norm_eps: float = 1e-12, gradient_checkpointing: bool = False)

Bases: AFNOModel

Encoder-only AFNO model for representation learning.

This variant of AFNO is designed for tasks that require extracting representations rather than making predictions. It's particularly efficient for processing very long sequences due to the mode truncation.

Parameters:

Name Type Description Default
hidden_dim int

Hidden dimension size. Default is 768.

768
num_layers int

Number of AFNO layers. Default is 12.

12
max_sequence_length int

Maximum sequence length. Default is 1024.

1024
modes_seq int | None

Number of Fourier modes in sequence dimension.

None
modes_hidden int | None

Number of Fourier modes in hidden dimension.

None
mlp_ratio float

MLP expansion ratio. Default is 2.0.

2.0
use_positional_encoding bool

Whether to use positional encoding. Default is True.

True
positional_encoding_type str

Type of positional encoding. Default is 'sinusoidal'.

'sinusoidal'
dropout float

Dropout probability. Default is 0.1.

0.1
ffn_hidden_dim int | None

Hidden dimension for FFN. If None, defaults to 4 * hidden_dim.

None
norm_eps float

Epsilon for layer normalization. Default is 1e-12.

1e-12
gradient_checkpointing bool

Whether to use gradient checkpointing. Default is False.

False
Source code in spectrans/models/afno.py
def __init__(
    self,
    hidden_dim: int = 768,
    num_layers: int = 12,
    max_sequence_length: int = 1024,
    modes_seq: int | None = None,
    modes_hidden: int | None = None,
    mlp_ratio: float = 2.0,
    use_positional_encoding: bool = True,
    positional_encoding_type: PositionalEncodingType = "sinusoidal",
    dropout: float = 0.1,
    ffn_hidden_dim: int | None = None,
    norm_eps: float = 1e-12,
    gradient_checkpointing: bool = False,
):
    super().__init__(
        vocab_size=None,  # No token embeddings for encoder
        hidden_dim=hidden_dim,
        num_layers=num_layers,
        max_sequence_length=max_sequence_length,
        modes_seq=modes_seq,
        modes_hidden=modes_hidden,
        mlp_ratio=mlp_ratio,
        num_classes=None,  # No classification head
        use_positional_encoding=use_positional_encoding,
        positional_encoding_type=positional_encoding_type,
        dropout=dropout,
        ffn_hidden_dim=ffn_hidden_dim,
        norm_eps=norm_eps,
        output_type="none",  # Return hidden states
        gradient_checkpointing=gradient_checkpointing,
    )

Functions