
FNO Transformer

spectrans.models.fno_transformer

Fourier Neural Operator (FNO) transformer models.

This module implements transformer models based on the Fourier Neural Operator, which learns mappings between function spaces by parameterizing integral kernels in the Fourier domain. These models achieve \(O(n \log n)\) complexity through FFT operations and are particularly effective for learning solution operators.
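An operator between function spaces does not depend on a fixed grid. Because the learned weights act on Fourier coefficients, the same weights can be applied to the same function sampled at different resolutions. Below is a minimal single-channel numpy sketch. It assumes a uniform periodic grid on [0, 1) and `norm="forward"` scaling, so the coefficients do not depend on the number of samples; spectrans's own normalization is not shown on this page. `spectral_conv_1d` is a hypothetical helper.

```python
import numpy as np


def spectral_conv_1d(v: np.ndarray, weights: np.ndarray) -> np.ndarray:
    """Single-channel spectral convolution on n uniform samples of [0, 1).

    v: real samples, shape (n,); weights: complex, shape (k,).
    norm="forward" divides by n in the forward transform only, so the
    coefficients of a band-limited function do not depend on n.
    """
    n = v.shape[0]
    k = weights.shape[0]
    v_hat = np.fft.rfft(v, norm="forward")     # n // 2 + 1 coefficients
    u_hat = np.zeros_like(v_hat)
    u_hat[:k] = weights * v_hat[:k]            # learned weights on the lowest k modes only
    return np.fft.irfft(u_hat, n=n, norm="forward")


def f(x: np.ndarray) -> np.ndarray:
    # Band-limited test function with energy only in modes 1 and 3
    return np.sin(2 * np.pi * x) + 0.5 * np.cos(2 * np.pi * 3 * x)


rng = np.random.default_rng(0)
weights = rng.standard_normal(8) + 1j * rng.standard_normal(8)

coarse_x = np.arange(64) / 64
fine_x = np.arange(128) / 128
u_coarse = spectral_conv_1d(f(coarse_x), weights)
u_fine = spectral_conv_1d(f(fine_x), weights)

# Every other point of the fine grid coincides with the coarse grid
print(np.allclose(u_coarse, u_fine[::2]))  # True
```

Sampling at 64 or 128 points gives the same output wherever the grids coincide. Only the lowest 8 modes are touched, and the test function has no energy above mode 3.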

The FNO mechanism learns integral operators by parameterizing convolution kernels in the Fourier domain, enabling efficient global interactions through spectral truncation and complex-valued weight multiplication.
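Spectral truncation has two visible effects. Components above the cutoff never reach the output, and a single input position influences every output position. Below is a minimal numpy sketch with identity weights on the lowest k modes. This is a simplification: a trained layer multiplies these modes by learned complex weights.

```python
import numpy as np


def truncate_modes(v: np.ndarray, k: int) -> np.ndarray:
    """Keep the lowest k Fourier modes of a real signal (identity weights)."""
    v_hat = np.fft.rfft(v)
    v_hat[k:] = 0                       # spectral truncation
    return np.fft.irfft(v_hat, n=len(v))


n, k = 128, 16
x = np.arange(n) / n
low = np.sin(2 * np.pi * 3 * x)         # mode 3, below the cutoff
high = np.sin(2 * np.pi * 40 * x)       # mode 40, above the cutoff

# Effect 1: the mode-40 component is filtered out entirely
out = truncate_modes(low + high, k)
print(np.allclose(out, low))  # True

# Effect 2: an impulse at position 0 spreads to every output position
impulse = np.zeros(n)
impulse[0] = 1.0
response = truncate_modes(impulse, k)
print(np.all(np.abs(response) > 1e-12))  # True
```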

Classes:

FNOTransformer
    Complete transformer model using Fourier Neural Operators.
FNOEncoder
    Encoder-only model for representation learning with FNO.
FNODecoder
    Decoder model with a language-modeling head for generation tasks.

Examples:

Basic FNO transformer:

>>> import torch
>>> from spectrans.models.fno_transformer import FNOTransformer
>>> model = FNOTransformer(
...     hidden_dim=512,
...     num_layers=6,
...     modes=32,
...     max_sequence_length=1024
... )
>>> x = torch.randn(32, 100, 512)  # (batch, seq_len, dim)
>>> output = model(inputs_embeds=x)
>>> assert output.shape == x.shape

Using with token inputs and classification:

>>> model = FNOTransformer(
...     vocab_size=10000,
...     hidden_dim=512,
...     num_layers=6,
...     modes=16,
...     num_classes=10,
...     max_sequence_length=512
... )
>>> input_ids = torch.randint(0, 10000, (32, 100))
>>> logits = model(input_ids)
>>> assert logits.shape == (32, 10)

2D FNO for image-like sequence data:

>>> from spectrans.models.fno_transformer import FNOTransformer
>>> model = FNOTransformer(
...     hidden_dim=512,
...     num_layers=6,
...     modes=32,
...     use_2d=True,
...     spatial_dim=64,  # Sequence viewed as 64x64 spatial grid
...     max_sequence_length=4096
... )
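The use_2d example treats the 4096-long sequence as a 64x64 grid. Below is a minimal numpy sketch of what a 2D spectral convolution on such a grid could look like. It assumes rfft2 over the two grid axes and truncation to the lowest non-negative frequencies along each axis. The original FNO also keeps the negative-frequency block along the first axis; this hypothetical `spectral_conv_2d` omits it for brevity.

```python
import numpy as np


def spectral_conv_2d(seq: np.ndarray, weights: np.ndarray, spatial_dim: int) -> np.ndarray:
    """View (batch, spatial_dim**2, d_in) as a grid and mix its low 2D modes.

    weights: complex array of shape (m1, m2, d_in, d_out); only frequency
    rows 0..m1-1 and rfft columns 0..m2-1 are kept, the rest are zeroed.
    """
    batch, n, d_in = seq.shape
    if n != spatial_dim * spatial_dim:
        raise ValueError("sequence length must equal spatial_dim**2")
    grid = seq.reshape(batch, spatial_dim, spatial_dim, d_in)
    g_hat = np.fft.rfft2(grid, axes=(1, 2))           # (batch, s, s // 2 + 1, d_in)
    m1, m2, _, d_out = weights.shape
    out_hat = np.zeros(g_hat.shape[:3] + (d_out,), dtype=np.complex128)
    # Per-mode channel mixing on the retained low-frequency block
    out_hat[:, :m1, :m2] = np.einsum("bxyi,xyio->bxyo", g_hat[:, :m1, :m2], weights)
    out = np.fft.irfft2(out_hat, s=(spatial_dim, spatial_dim), axes=(1, 2))
    return out.reshape(batch, n, d_out)              # back to sequence layout


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 64 * 64, 4))
w = (rng.standard_normal((8, 8, 4, 4)) + 1j * rng.standard_normal((8, 8, 4, 4))) / 4
y = spectral_conv_2d(x, w, spatial_dim=64)
print(y.shape)  # (2, 4096, 4)
```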

Notes

Mathematical Foundation:

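The FNO learns operators between function spaces through integral transforms with a kernel \(\kappa\):

\[ (\mathcal{K} v)(x) = \int \kappa(x, y)\, v(y)\, dy \]

FNO restricts the kernel to be translation invariant, \(\kappa(x, y) = K(x - y)\), so the operator becomes a convolution, \(\mathcal{K} v = K \ast v\). In the Fourier domain, convolution becomes multiplication, and \(R_{\theta}\) is learned in place of \(\mathcal{F}[K]\):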

\[ \mathcal{F}[K \ast v] = R_{\theta} \cdot \mathcal{F}[v] \]
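The step from the integral to a pointwise product is the convolution theorem. On a periodic grid, its discrete form states that circular convolution equals the inverse DFT of the product of the two DFTs. This numpy check confirms it for a random kernel \(K\) and input \(v\):

```python
import numpy as np

rng = np.random.default_rng(0)
n = 32
K = rng.standard_normal(n)   # convolution kernel sampled on a periodic grid
v = rng.standard_normal(n)   # input function sampled on the same grid

# Direct circular convolution: (K * v)[m] = sum_j K[(m - j) mod n] * v[j]
direct = np.array([sum(K[(m - j) % n] * v[j] for j in range(n)) for m in range(n)])

# Convolution theorem: multiply the spectra, then transform back
spectral = np.fft.irfft(np.fft.rfft(K) * np.fft.rfft(v), n=n)

print(np.allclose(direct, spectral))  # True
```

FNO skips the spatial kernel entirely and learns the left factor of the product, \(R_{\theta}\), directly on the retained modes.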

where \(R_{\theta}\) are learnable complex weights defined only on the lowest \(k\) frequency modes:

\[ R_{\theta} \in \mathbb{C}^{k \times d_{in} \times d_{out}} \]
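\(R_{\theta}\) is indexed by mode rather than by position, so its size depends on \(k\) and the channel widths but not on the sequence length \(n\). Here is a worked count for the defaults modes=32 and hidden_dim=512. It assumes one full \(k \times d_{in} \times d_{out}\) complex tensor per spectral layer with \(d_{in} = d_{out}\) = hidden_dim; whether spectrans's FNOBlock stores exactly this is not shown here.

```python
def spectral_weight_count(modes: int, d_in: int, d_out: int) -> dict[str, int]:
    """Number of entries in R_theta with shape (modes, d_in, d_out)."""
    complex_entries = modes * d_in * d_out
    # Each complex entry is stored as a real and an imaginary part
    return {"complex": complex_entries, "real": 2 * complex_entries}


counts = spectral_weight_count(modes=32, d_in=512, d_out=512)
print(counts)  # {'complex': 8388608, 'real': 16777216}
```

The count is the same whether max_sequence_length is 1024 or 4096.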

The spectral convolution is computed as:

  1. Forward FFT: \(\hat{v} = \mathcal{F}[v]\)
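  2. Mode truncation: keep only the lowest \(k\) modes and set the rest to zero
  3. Complex multiplication: for each retained mode \(j < k\), mix channels with \((\hat{u}_j)_o = \sum_{i} (\hat{v}_j)_i \, (R_{\theta,j})_{i,o}\)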
  4. Inverse FFT: \(u = \mathcal{F}^{-1}[\hat{u}]\)
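The four steps translate almost line by line into array code. Below is a minimal numpy sketch for inputs laid out as (batch, seq_len, dim), matching the examples above. `spectral_conv` is a hypothetical name and the default FFT normalization is an assumption. The sketch illustrates the technique rather than reproducing spectrans.layers.operators.fno.

```python
import numpy as np


def spectral_conv(v: np.ndarray, weights: np.ndarray) -> np.ndarray:
    """FNO-style spectral convolution along the sequence axis.

    v:       real input of shape (batch, n, d_in)
    weights: complex R_theta of shape (k, d_in, d_out)
    returns: real output of shape (batch, n, d_out)
    """
    batch, n, _ = v.shape
    k, _, d_out = weights.shape
    # 1. Forward FFT along the sequence axis (real input -> n // 2 + 1 modes)
    v_hat = np.fft.rfft(v, axis=1)
    if k > v_hat.shape[1]:
        raise ValueError(f"modes={k} exceeds the {v_hat.shape[1]} available for n={n}")
    # 2. Mode truncation: start from all-zero modes; only the lowest k get filled
    u_hat = np.zeros((batch, v_hat.shape[1], d_out), dtype=np.complex128)
    # 3. Per-mode channel mixing: u_hat[b, j, o] = sum_i v_hat[b, j, i] * R[j, i, o]
    u_hat[:, :k, :] = np.einsum("bji,jio->bjo", v_hat[:, :k, :], weights)
    # 4. Inverse FFT back to n real samples
    return np.fft.irfft(u_hat, n=n, axis=1)


rng = np.random.default_rng(0)
x = rng.standard_normal((4, 100, 8))
R = (rng.standard_normal((16, 8, 8)) + 1j * rng.standard_normal((16, 8, 8))) / 8
y = spectral_conv(x, R)
print(y.shape)  # (4, 100, 8)
```

For a real sequence of length n, the real FFT yields n // 2 + 1 modes, so in this sketch modes cannot exceed that count. With seq_len=100 as in the examples above, at most 51 modes are available.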

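The FFTs cost \(O(n \log n)\) per channel, and the mode-wise multiplication costs \(O(k \, d_{in} d_{out})\) independent of \(n\). The layer therefore scales as \(O(n \log n)\) in sequence length, while every output position still depends on every input position.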

References

Zongyi Li, Nikola Kovachki, Kamyar Azizzadenesheli, Burigede Liu, Kaushik Bhattacharya, Andrew Stuart, and Anima Anandkumar. 2021. Fourier neural operator for parametric partial differential equations. In Proceedings of the International Conference on Learning Representations (ICLR).

John Guibas, Morteza Mardani, Zongyi Li, Andrew Tao, Anima Anandkumar, and Bryan Catanzaro. 2022. Adaptive Fourier neural operators: Efficient token mixers for transformers. In Proceedings of the International Conference on Learning Representations (ICLR).

See Also

spectrans.layers.operators.fno : Core FNO layer implementations.
spectrans.transforms.fourier : FFT operations used by FNO.

Classes

FNOTransformer

FNOTransformer(vocab_size: int | None = None, hidden_dim: int = 512, num_layers: int = 6, max_sequence_length: int = 1024, modes: int = 32, mlp_ratio: float = 2.0, use_2d: bool = False, spatial_dim: int | None = None, num_classes: int | None = None, ffn_hidden_dim: int | None = None, dropout: float = 0.0, use_positional_encoding: bool = True, positional_encoding_type: PositionalEncodingType = 'sinusoidal', gradient_checkpointing: bool = False)

Bases: BaseModel

Fourier Neural Operator transformer model.

This model uses Fourier Neural Operators for sequence mixing, achieving O(n log n) complexity through FFT operations. The model learns mappings between function spaces by parameterizing kernels in the Fourier domain.

Parameters:

vocab_size : int | None, default None
    Size of the vocabulary for token embeddings. If None, the model expects pre-embedded inputs passed as inputs_embeds.
hidden_dim : int, default 512
    Hidden dimension of the model.
num_layers : int, default 6
    Number of transformer blocks.
max_sequence_length : int, default 1024
    Maximum sequence length the model can process.
modes : int, default 32
    Number of Fourier modes to retain (frequency truncation).
mlp_ratio : float, default 2.0
    Expansion ratio for the MLP in each FNO block.
use_2d : bool, default False
    Whether to use 2D spectral convolutions on spatial data; requires spatial_dim. In the build_blocks source below, use_2d and spatial_dim are validated and stored but not passed to FNOBlock.
spatial_dim : int | None, default None
    Side length of the square grid when use_2d=True; max_sequence_length must equal spatial_dim².
num_classes : int | None, default None
    Number of output classes for classification.
ffn_hidden_dim : int | None, default None
    Hidden dimension of the feedforward network; defaults to 4 * hidden_dim. This value is passed to BaseModel, while build_blocks passes mlp_ratio, not ffn_hidden_dim, to FNOBlock.
dropout : float, default 0.0
    Dropout probability.
use_positional_encoding : bool, default True
    Whether to use positional encoding.
positional_encoding_type : PositionalEncodingType, default "sinusoidal"
    Type of positional encoding ("sinusoidal" or "learned").
gradient_checkpointing : bool, default False
    Whether to use gradient checkpointing to reduce memory use.

Attributes:

blocks : ModuleList
    Stack of FNO transformer blocks.

Examples:

>>> import torch
>>> from spectrans.models.fno_transformer import FNOTransformer
>>> model = FNOTransformer(
...     hidden_dim=512,
...     num_layers=6,
...     modes=32,
...     max_sequence_length=1024
... )
>>> x = torch.randn(32, 100, 512)
>>> output = model(inputs_embeds=x)
>>> assert output.shape == x.shape

Methods:

build_blocks
    Build transformer blocks with FNO layers.
from_config
    Create model from configuration.

Source code in spectrans/models/fno_transformer.py
def __init__(
    self,
    vocab_size: int | None = None,
    hidden_dim: int = 512,
    num_layers: int = 6,
    max_sequence_length: int = 1024,
    modes: int = 32,
    mlp_ratio: float = 2.0,
    use_2d: bool = False,
    spatial_dim: int | None = None,
    num_classes: int | None = None,
    ffn_hidden_dim: int | None = None,
    dropout: float = 0.0,
    use_positional_encoding: bool = True,
    positional_encoding_type: PositionalEncodingType = "sinusoidal",
    gradient_checkpointing: bool = False,
):
    # Store FNO-specific parameters
    self.modes = modes
    self.mlp_ratio = mlp_ratio
    self.use_2d = use_2d
    self.spatial_dim = spatial_dim
    self.dropout_rate = dropout

    # Validate 2D configuration
    if use_2d and spatial_dim is None:
        raise ValueError("spatial_dim must be specified when use_2d=True")
    if use_2d and spatial_dim is not None and spatial_dim * spatial_dim != max_sequence_length:
        raise ValueError(
            f"For 2D FNO, max_sequence_length ({max_sequence_length}) "
            f"must equal spatial_dim² ({spatial_dim}² = {spatial_dim * spatial_dim})"
        )

    super().__init__(
        vocab_size=vocab_size,
        hidden_dim=hidden_dim,
        num_layers=num_layers,
        max_sequence_length=max_sequence_length,
        num_classes=num_classes,
        ffn_hidden_dim=ffn_hidden_dim,
        dropout=dropout,
        use_positional_encoding=use_positional_encoding,
        positional_encoding_type=positional_encoding_type,
        gradient_checkpointing=gradient_checkpointing,
    )
Functions
build_blocks
build_blocks() -> ModuleList

Build transformer blocks with FNO layers.

Returns:

ModuleList
    List of FNO transformer blocks.

Source code in spectrans/models/fno_transformer.py
def build_blocks(self) -> nn.ModuleList:
    """Build transformer blocks with FNO layers.

    Returns
    -------
    nn.ModuleList
        List of FNO transformer blocks.
    """
    blocks = []
    for _ in range(self.num_layers):
        # Create FNO block with appropriate configuration
        fno_block = FNOBlock(
            hidden_dim=self.hidden_dim,
            modes=self.modes,
            mlp_ratio=self.mlp_ratio,
            dropout=self.dropout_rate,
        )
        blocks.append(fno_block)

    return nn.ModuleList(blocks)
from_config classmethod
from_config(config: FNOTransformerConfig) -> FNOTransformer

Create model from configuration.

Parameters:

config : FNOTransformerConfig (required)
    Model configuration object. Its sequence_length field is passed as max_sequence_length.

Returns:

FNOTransformer
    Instantiated model.

Source code in spectrans/models/fno_transformer.py
@classmethod
def from_config(cls, config: "FNOTransformerConfig") -> "FNOTransformer":  # type: ignore[override]
    """Create model from configuration.

    Parameters
    ----------
    config : FNOTransformerConfig
        Model configuration object.

    Returns
    -------
    FNOTransformer
        Instantiated model.
    """
    return cls(
        vocab_size=config.vocab_size,
        hidden_dim=config.hidden_dim,
        num_layers=config.num_layers,
        max_sequence_length=config.sequence_length,
        modes=config.modes,
        mlp_ratio=config.mlp_ratio,
        use_2d=config.use_2d,
        spatial_dim=config.spatial_dim,
        num_classes=config.num_classes,
        ffn_hidden_dim=config.ffn_hidden_dim,
        dropout=config.dropout,
        use_positional_encoding=config.use_positional_encoding,
        positional_encoding_type=config.positional_encoding_type,
        gradient_checkpointing=config.gradient_checkpointing,
    )

FNOEncoder

FNOEncoder(hidden_dim: int = 512, num_layers: int = 6, max_sequence_length: int = 1024, modes: int = 32, mlp_ratio: float = 2.0, ffn_hidden_dim: int | None = None, dropout: float = 0.0, use_positional_encoding: bool = True, positional_encoding_type: PositionalEncodingType = 'sinusoidal', gradient_checkpointing: bool = False)

Bases: BaseModel

Encoder-only FNO model for representation learning.

This model uses stacked FNO blocks without causal masking, suitable for bidirectional encoding tasks like feature extraction and representation learning.
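FFT-based mixing is bidirectional by construction. Every retained mode is a sum over all positions, so changing the last token also changes the output at the first. Below is a single-channel numpy check with identity weights on the lowest k modes, which simplifies a trained layer.

```python
import numpy as np


def lowpass_mix(v: np.ndarray, k: int) -> np.ndarray:
    """Single-channel spectral layer with identity weights on the lowest k modes."""
    v_hat = np.fft.rfft(v)
    v_hat[k:] = 0
    return np.fft.irfft(v_hat, n=len(v))


rng = np.random.default_rng(0)
v = rng.standard_normal(64)
v_changed = v.copy()
v_changed[-1] += 1.0                    # perturb only the last position

delta = lowpass_mix(v_changed, k=8) - lowpass_mix(v, k=8)
print(abs(delta[0]) > 1e-6)  # True: position 0 depends on position 63
```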

Parameters:

hidden_dim : int, default 512
    Hidden dimension of the model.
num_layers : int, default 6
    Number of encoder blocks.
max_sequence_length : int, default 1024
    Maximum sequence length.
modes : int, default 32
    Number of Fourier modes to retain.
mlp_ratio : float, default 2.0
    MLP expansion ratio in FNO blocks.
ffn_hidden_dim : int | None, default None
    Hidden dimension of the feedforward network, passed to BaseModel; build_blocks passes mlp_ratio, not ffn_hidden_dim, to FNOBlock.
dropout : float, default 0.0
    Dropout probability.
use_positional_encoding : bool, default True
    Whether to use positional encoding.
positional_encoding_type : PositionalEncodingType, default "sinusoidal"
    Type of positional encoding.
gradient_checkpointing : bool, default False
    Whether to use gradient checkpointing.

Examples:

>>> import torch
>>> from spectrans.models.fno_transformer import FNOEncoder
>>> encoder = FNOEncoder(
...     hidden_dim=512,
...     num_layers=6,
...     modes=32,
...     max_sequence_length=1024
... )
>>> x = torch.randn(32, 100, 512)
>>> encoded = encoder(inputs_embeds=x)
>>> assert encoded.shape == x.shape

Methods:

build_blocks
    Build encoder blocks with FNO layers.

Source code in spectrans/models/fno_transformer.py
def __init__(
    self,
    hidden_dim: int = 512,
    num_layers: int = 6,
    max_sequence_length: int = 1024,
    modes: int = 32,
    mlp_ratio: float = 2.0,
    ffn_hidden_dim: int | None = None,
    dropout: float = 0.0,
    use_positional_encoding: bool = True,
    positional_encoding_type: PositionalEncodingType = "sinusoidal",
    gradient_checkpointing: bool = False,
):
    # Store FNO-specific parameters
    self.modes = modes
    self.mlp_ratio = mlp_ratio
    self.dropout_rate = dropout

    super().__init__(
        vocab_size=None,  # Encoder doesn't need vocab
        hidden_dim=hidden_dim,
        num_layers=num_layers,
        max_sequence_length=max_sequence_length,
        num_classes=None,  # No classification head
        ffn_hidden_dim=ffn_hidden_dim,
        dropout=dropout,
        use_positional_encoding=use_positional_encoding,
        positional_encoding_type=positional_encoding_type,
        gradient_checkpointing=gradient_checkpointing,
    )

    # Set output type to none for encoder
    self.output_type = "none"
Functions
build_blocks
build_blocks() -> ModuleList

Build encoder blocks with FNO layers.

Returns:

ModuleList
    List of FNO encoder blocks.

Source code in spectrans/models/fno_transformer.py
def build_blocks(self) -> nn.ModuleList:
    """Build encoder blocks with FNO layers.

    Returns
    -------
    nn.ModuleList
        List of FNO encoder blocks.
    """
    blocks = []
    for _ in range(self.num_layers):
        fno_block = FNOBlock(
            hidden_dim=self.hidden_dim,
            modes=self.modes,
            mlp_ratio=self.mlp_ratio,
            dropout=self.dropout_rate,
        )
        blocks.append(fno_block)

    return nn.ModuleList(blocks)

FNODecoder

FNODecoder(vocab_size: int, hidden_dim: int = 512, num_layers: int = 12, max_sequence_length: int = 2048, modes: int = 32, mlp_ratio: float = 2.0, causal: bool = True, ffn_hidden_dim: int | None = None, dropout: float = 0.0, use_positional_encoding: bool = True, positional_encoding_type: PositionalEncodingType = 'sinusoidal', gradient_checkpointing: bool = False)

Bases: BaseModel

Decoder FNO model for generation tasks.

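This model adds a language-modeling head on top of stacked FNO blocks and is intended for autoregressive generation. As the comment in build_blocks states, causality in the spectral domain requires special handling that is not yet implemented. The causal flag is stored, but the blocks are the same FNOBlocks used by FNOEncoder, so outputs can depend on later positions.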

Parameters:

vocab_size : int (required)
    Size of the vocabulary for generation.
hidden_dim : int, default 512
    Hidden dimension of the model.
num_layers : int, default 12
    Number of decoder blocks.
max_sequence_length : int, default 2048
    Maximum sequence length.
modes : int, default 32
    Number of Fourier modes to retain. The build_blocks source below applies no causal adjustment to this value.
mlp_ratio : float, default 2.0
    MLP expansion ratio.
causal : bool, default True
    Whether causal processing is requested. The value is stored as self.causal, but build_blocks does not currently use it.
ffn_hidden_dim : int | None, default None
    Hidden dimension of the feedforward network, passed to BaseModel.
dropout : float, default 0.0
    Dropout probability.
use_positional_encoding : bool, default True
    Whether to use positional encoding.
positional_encoding_type : PositionalEncodingType, default "sinusoidal"
    Type of positional encoding.
gradient_checkpointing : bool, default False
    Whether to use gradient checkpointing.

Examples:

>>> import torch
>>> from spectrans.models.fno_transformer import FNODecoder
>>> decoder = FNODecoder(
...     vocab_size=10000,
...     hidden_dim=512,
...     num_layers=12,
...     modes=32,
...     causal=True,
...     max_sequence_length=2048
... )
>>> input_ids = torch.randint(0, 10000, (32, 100))
>>> logits = decoder(input_ids)
>>> assert logits.shape == (32, 100, 10000)

Methods:

build_blocks
    Build decoder blocks with causal FNO layers.
forward
    Forward pass through the decoder.

Source code in spectrans/models/fno_transformer.py
def __init__(
    self,
    vocab_size: int,
    hidden_dim: int = 512,
    num_layers: int = 12,
    max_sequence_length: int = 2048,
    modes: int = 32,
    mlp_ratio: float = 2.0,
    causal: bool = True,
    ffn_hidden_dim: int | None = None,
    dropout: float = 0.0,
    use_positional_encoding: bool = True,
    positional_encoding_type: PositionalEncodingType = "sinusoidal",
    gradient_checkpointing: bool = False,
):
    # Store FNO-specific parameters
    self.modes = modes
    self.mlp_ratio = mlp_ratio
    self.causal = causal
    self.dropout_rate = dropout

    super().__init__(
        vocab_size=vocab_size,
        hidden_dim=hidden_dim,
        num_layers=num_layers,
        max_sequence_length=max_sequence_length,
        num_classes=None,  # Decoder uses LM head instead
        ffn_hidden_dim=ffn_hidden_dim,
        dropout=dropout,
        use_positional_encoding=use_positional_encoding,
        positional_encoding_type=positional_encoding_type,
        gradient_checkpointing=gradient_checkpointing,
    )

    # Add language modeling head
    self.lm_head = nn.Linear(hidden_dim, vocab_size)
    self.output_type = "lm"
Functions
build_blocks
build_blocks() -> ModuleList

Build decoder blocks with causal FNO layers.

Returns:

ModuleList
    List of causal FNO decoder blocks.

Source code in spectrans/models/fno_transformer.py
def build_blocks(self) -> nn.ModuleList:
    """Build decoder blocks with causal FNO layers.

    Returns
    -------
    nn.ModuleList
        List of causal FNO decoder blocks.
    """
    blocks = []
    for _ in range(self.num_layers):
        # Create FNO block
        # Note: Causality in spectral domain requires special handling
        # This is a simplified version - full causality would need custom implementation
        fno_block = FNOBlock(
            hidden_dim=self.hidden_dim,
            modes=self.modes,
            mlp_ratio=self.mlp_ratio,
            dropout=self.dropout_rate,
        )
        blocks.append(fno_block)

    return nn.ModuleList(blocks)
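One standard way to make FFT-based sequence mixing causal is to learn a kernel on non-negative lags only. Both signals are then zero-padded to twice the sequence length, so the FFT's circular wrap-around cannot carry later positions into earlier outputs. This technique differs from FNO's truncated-mode weights and is not what spectrans's FNODecoder currently does. Below is a minimal numpy sketch with a hypothetical `causal_fft_conv` helper.

```python
import numpy as np


def causal_fft_conv(v: np.ndarray, kernel: np.ndarray) -> np.ndarray:
    """Causal convolution y[t] = sum_{s <= t} kernel[t - s] * v[s], via FFTs.

    Zero-padding both signals to length 2n turns the FFT's circular
    convolution into a linear one, so no output depends on later inputs.
    """
    n = len(v)
    size = 2 * n
    y = np.fft.irfft(np.fft.rfft(v, n=size) * np.fft.rfft(kernel, n=size), n=size)
    return y[:n]  # keep the first n outputs of the linear convolution


rng = np.random.default_rng(0)
n = 64
v = rng.standard_normal(n)
kernel = rng.standard_normal(n)   # kernel defined on lags 0..n-1 only
y = causal_fft_conv(v, kernel)
print(np.allclose(y, np.convolve(v, kernel)[:n]))  # True

v2 = v.copy()
v2[40] += 1.0                     # change position 40
y2 = causal_fft_conv(v2, kernel)
print(np.allclose(y[:40], y2[:40]))  # True: earlier outputs are unchanged
```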
forward
forward(input_ids: Tensor | None = None, inputs_embeds: Tensor | None = None, attention_mask: Tensor | None = None) -> Tensor

Forward pass through the decoder.

Parameters:

input_ids : Tensor | None, default None
    Input token IDs of shape (batch_size, sequence_length).
inputs_embeds : Tensor | None, default None
    Pre-embedded inputs of shape (batch_size, sequence_length, hidden_dim).
attention_mask : Tensor | None, default None
    Mask for padding positions, forwarded unchanged to BaseModel.forward.

Returns:

Tensor
    Logits of shape (batch_size, sequence_length, vocab_size).

Source code in spectrans/models/fno_transformer.py
def forward(
    self,
    input_ids: torch.Tensor | None = None,
    inputs_embeds: torch.Tensor | None = None,
    attention_mask: torch.Tensor | None = None,
) -> torch.Tensor:
    """Forward pass through the decoder.

    Parameters
    ----------
    input_ids : torch.Tensor | None, optional
        Input token IDs of shape (batch_size, sequence_length).
    inputs_embeds : torch.Tensor | None, optional
        Pre-embedded inputs of shape (batch_size, sequence_length, hidden_dim).
    attention_mask : torch.Tensor | None, optional
        Attention mask for padding.

    Returns
    -------
    torch.Tensor
        Logits of shape (batch_size, sequence_length, vocab_size).
    """
    # Use parent class forward for processing
    hidden_states = super().forward(
        input_ids=input_ids,
        inputs_embeds=inputs_embeds,
        attention_mask=attention_mask,
    )

    # Apply LM head
    logits = self.lm_head(hidden_states)
    return logits  # type: ignore[no-any-return]

Functions