FNO Transformer
spectrans.models.fno_transformer
Fourier Neural Operator (FNO) transformer models.
This module implements transformer models based on the Fourier Neural Operator, which learns mappings between function spaces by parameterizing integral kernels in the Fourier domain. Because sequence mixing is done with FFTs, these models scale as \(O(n \log n)\) in the sequence length \(n\), and they are particularly effective for learning solution operators, such as those of parametric partial differential equations.
Each FNO layer truncates the spectrum to a fixed number of low-frequency modes and multiplies the retained modes by learnable complex-valued weights, so every output position can depend on every input position (global mixing) at FFT cost.
Classes:
| Name | Description |
|---|---|
| `FNOTransformer` | Complete transformer model using Fourier Neural Operators. |
| `FNOEncoder` | Encoder-only model for representation learning with FNO. |
| `FNODecoder` | Decoder model with causal FNO support for generation tasks. |
Examples:
Basic FNO transformer:
>>> import torch
>>> from spectrans.models.fno_transformer import FNOTransformer
>>> model = FNOTransformer(
... hidden_dim=512,
... num_layers=6,
... modes=32,
... max_sequence_length=1024
... )
>>> x = torch.randn(32, 100, 512) # (batch, seq_len, dim)
>>> output = model(inputs_embeds=x)
>>> assert output.shape == x.shape
Using with token inputs and classification:
>>> model = FNOTransformer(
... vocab_size=10000,
... hidden_dim=512,
... num_layers=6,
... modes=16,
... num_classes=10,
... max_sequence_length=512
... )
>>> input_ids = torch.randint(0, 10000, (32, 100))
>>> logits = model(input_ids)
>>> assert logits.shape == (32, 10)
2D FNO for image-like sequence data:
>>> from spectrans.models.fno_transformer import FNOTransformer
>>> model = FNOTransformer(
... hidden_dim=512,
... num_layers=6,
... modes=32,
... use_2d=True,
... spatial_dim=64, # Sequence viewed as 64x64 spatial grid
... max_sequence_length=4096
... )
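With `use_2d=True`, a sequence of length `spatial_dim**2` is treated as a square grid (in the example above, 4096 tokens become a 64 × 64 grid). The NumPy sketch below illustrates one way a 2D spectral convolution can work on such a sequence. It assumes a row-major reshape and the corner-mode truncation of the original 2D FNO (low positive and low negative frequencies along the first grid axis, low non-negative frequencies along the `rfft2` axis); the function name and weight layout are hypothetical and not the spectrans API.

```python
import numpy as np


def spectral_conv_2d(x, weights_pos, weights_neg, spatial_dim, modes):
    """Hypothetical 2D spectral convolution over a flattened square grid.

    x: real array of shape (spatial_dim**2, d), one row per sequence position.
    weights_pos, weights_neg: complex arrays of shape (modes, modes, d, d);
    entry [a, b] is a (d_out, d_in) channel-mixing matrix for the low
    positive / low negative frequencies along the first grid axis.
    """
    n, d = x.shape
    if n != spatial_dim * spatial_dim:
        raise ValueError("sequence length must equal spatial_dim**2")
    if 2 * modes > spatial_dim:
        raise ValueError("need 2 * modes <= spatial_dim so the two corners do not overlap")
    grid = x.reshape(spatial_dim, spatial_dim, d)   # row-major view as a grid
    x_hat = np.fft.rfft2(grid, axes=(0, 1))         # (H, W // 2 + 1, d), complex
    out_hat = np.zeros_like(x_hat)                  # discarded modes stay zero
    out_hat[:modes, :modes] = np.einsum(
        "abi,aboi->abo", x_hat[:modes, :modes], weights_pos
    )
    out_hat[-modes:, :modes] = np.einsum(
        "abi,aboi->abo", x_hat[-modes:, :modes], weights_neg
    )
    out = np.fft.irfft2(out_hat, s=(spatial_dim, spatial_dim), axes=(0, 1))
    return out.reshape(n, d)                        # back to (sequence_length, d)
```

With identity weights, a constant input passes through unchanged (only the zero-frequency mode is nonzero), while a checkerboard pattern, whose energy sits at the highest frequency, is removed entirely.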
Notes
Mathematical Foundation:
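The FNO learns operators between function spaces through integral transforms. With a translation-invariant kernel \(\kappa_{\theta}(x, y) = \kappa_{\theta}(x - y)\), the kernel integral operator over the domain \(D\) is a convolution:
\[(\mathcal{K}v)(x) = \int_{D} \kappa_{\theta}(x - y)\, v(y)\, dy\]
In the Fourier domain, convolution becomes multiplication:
\[(\mathcal{K}v)(x) = \mathcal{F}^{-1}\big[R_{\theta} \cdot \mathcal{F}[v]\big](x)\]
where \(R_{\theta}\), which parameterizes \(\mathcal{F}[\kappa_{\theta}]\) directly, is a set of learnable complex weights truncated to the lowest \(k\) frequency modes, with one \(d \times d\) matrix per retained mode (\(d\) is the hidden dimension):
\[R_{\theta} = \{R_{\theta,j}\}_{j=0}^{k-1}, \quad R_{\theta,j} \in \mathbb{C}^{d \times d}, \qquad (R_{\theta} \cdot \hat{v})_j = \begin{cases} R_{\theta,j}\, \hat{v}_j, & j < k \\ 0, & j \ge k \end{cases}\]
The spectral convolution is computed as:
- Forward FFT: \(\hat{v} = \mathcal{F}[v]\)
- Mode truncation: keep only the lowest \(k\) modes \(\hat{v}_j\), \(j < k\)
- Complex multiplication: \(\hat{u}_j = R_{\theta,j} \cdot \hat{v}_j\) for \(j < k\), and \(\hat{u}_j = 0\) for \(j \ge k\)
- Inverse FFT: \(u = \mathcal{F}^{-1}[\hat{u}]\)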
This achieves \(O(n \log n)\) complexity while learning global dependencies through the spectral parameterization.
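The four steps above can be sketched in NumPy for a single sequence of shape \((n, d)\). This is an illustrative sketch of the spectral convolution from the FNO paper, not the spectrans implementation (see `spectrans.layers.operators.fno`); the function name and the `(modes, d_out, d_in)` weight layout are assumptions.

```python
import numpy as np


def spectral_conv_1d(v, weights, modes):
    """Minimal 1D FNO spectral convolution: FFT, truncate, mix channels, inverse FFT.

    v: real array of shape (n, d), n sequence positions with d channels.
    weights: complex array of shape (modes, d, d); weights[j] plays the role
    of R_{theta,j} as a (d_out, d_in) matrix.
    """
    n, d = v.shape
    if modes > n // 2 + 1:
        raise ValueError("modes cannot exceed n // 2 + 1 for a real FFT")
    v_hat = np.fft.rfft(v, axis=0)        # forward FFT: (n // 2 + 1, d), complex
    u_hat = np.zeros_like(v_hat)          # modes j >= k stay zero (truncation)
    # Complex multiplication: u_hat_j = R_{theta,j} @ v_hat_j for j < k
    u_hat[:modes] = np.einsum("joi,ji->jo", weights, v_hat[:modes])
    return np.fft.irfft(u_hat, n=n, axis=0)  # inverse FFT: (n, d), real
```

With identity weights and all `n // 2 + 1` modes retained, the layer reduces to the identity; with fewer modes it acts as a learnable low-pass filter whose per-mode matrices also mix channels.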
References
Zongyi Li, Nikola Kovachki, Kamyar Azizzadenesheli, Burigede Liu, Kaushik Bhattacharya, Andrew Stuart, and Anima Anandkumar. 2021. Fourier neural operator for parametric partial differential equations. In Proceedings of the International Conference on Learning Representations (ICLR).
John Guibas, Morteza Mardani, Zongyi Li, Andrew Tao, Anima Anandkumar, and Bryan Catanzaro. 2022. Adaptive Fourier neural operators: Efficient token mixers for transformers. In Proceedings of the International Conference on Learning Representations (ICLR).
See Also
spectrans.layers.operators.fno : Core FNO layer implementations.
spectrans.transforms.fourier : FFT operations used by FNO.
Classes
FNOTransformer
FNOTransformer(vocab_size: int | None = None, hidden_dim: int = 512, num_layers: int = 6, max_sequence_length: int = 1024, modes: int = 32, mlp_ratio: float = 2.0, use_2d: bool = False, spatial_dim: int | None = None, num_classes: int | None = None, ffn_hidden_dim: int | None = None, dropout: float = 0.0, use_positional_encoding: bool = True, positional_encoding_type: PositionalEncodingType = 'sinusoidal', gradient_checkpointing: bool = False)
Bases: BaseModel
Fourier Neural Operator transformer model.
This model uses Fourier Neural Operators for sequence mixing, achieving O(n log n) complexity through FFT operations. The model learns mappings between function spaces by parameterizing kernels in the Fourier domain.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `vocab_size` | `int \| None` | Size of the vocabulary for token embeddings. If None, the model expects pre-embedded inputs (`inputs_embeds`). | `None` |
| `hidden_dim` | `int` | Hidden dimension size for the model. | `512` |
| `num_layers` | `int` | Number of transformer blocks. | `6` |
| `max_sequence_length` | `int` | Maximum sequence length the model can process. | `1024` |
| `modes` | `int` | Number of Fourier modes to retain (frequency truncation). | `32` |
| `mlp_ratio` | `float` | Expansion ratio for the MLP in FNO blocks. | `2.0` |
| `use_2d` | `bool` | Whether to use 2D spectral convolutions for spatial data. | `False` |
| `spatial_dim` | `int \| None` | Side length of the square grid used when `use_2d=True`; the sequence is viewed as a `spatial_dim` × `spatial_dim` grid, so its length is `spatial_dim**2`. | `None` |
| `num_classes` | `int \| None` | Number of output classes for classification. If None, the model returns hidden states of shape `(batch_size, sequence_length, hidden_dim)`. | `None` |
| `ffn_hidden_dim` | `int \| None` | Hidden dimension of the feedforward network. If None, defaults to `4 * hidden_dim`. | `None` |
| `dropout` | `float` | Dropout probability. | `0.0` |
| `use_positional_encoding` | `bool` | Whether to use positional encoding. | `True` |
| `positional_encoding_type` | `PositionalEncodingType` | Type of positional encoding (`"sinusoidal"` or `"learned"`). | `"sinusoidal"` |
| `gradient_checkpointing` | `bool` | Whether to use gradient checkpointing to save memory at the cost of recomputation in the backward pass. | `False` |
Attributes:
| Name | Type | Description |
|---|---|---|
| `blocks` | `ModuleList` | Stack of FNO transformer blocks. |
Examples:
>>> import torch
>>> from spectrans.models.fno_transformer import FNOTransformer
>>> model = FNOTransformer(
... hidden_dim=512,
... num_layers=6,
... modes=32,
... max_sequence_length=1024
... )
>>> x = torch.randn(32, 100, 512)
>>> output = model(inputs_embeds=x)
>>> assert output.shape == x.shape
Methods:
| Name | Description |
|---|---|
| `build_blocks` | Build transformer blocks with FNO layers. |
| `from_config` | Create model from configuration. |
Source code in spectrans/models/fno_transformer.py
Functions
build_blocks
Build transformer blocks with FNO layers.
Returns:
| Type | Description |
|---|---|
| `ModuleList` | List of FNO transformer blocks. |
from_config (classmethod)
from_config(config: FNOTransformerConfig) -> FNOTransformer
Create model from configuration.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `config` | `FNOTransformerConfig` | Model configuration object. | required |
Returns:
| Type | Description |
|---|---|
| `FNOTransformer` | Instantiated model. |
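`from_config` is a classmethod factory that builds the model from a configuration object. The sketch below shows the general pattern with stand-in classes; `ToyFNOConfig`, `ToyFNOModel`, and their fields are hypothetical and mirror only a few constructor arguments, so the real `FNOTransformerConfig` may define different or additional fields.

```python
from dataclasses import asdict, dataclass


@dataclass
class ToyFNOConfig:
    """Hypothetical stand-in for FNOTransformerConfig."""
    hidden_dim: int = 512
    num_layers: int = 6
    modes: int = 32
    max_sequence_length: int = 1024


class ToyFNOModel:
    """Hypothetical stand-in for FNOTransformer; stores its arguments only."""

    def __init__(self, hidden_dim=512, num_layers=6, modes=32, max_sequence_length=1024):
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.modes = modes
        self.max_sequence_length = max_sequence_length

    @classmethod
    def from_config(cls, config: ToyFNOConfig) -> "ToyFNOModel":
        # Unpack the config fields into constructor keyword arguments.
        return cls(**asdict(config))


model = ToyFNOModel.from_config(ToyFNOConfig(modes=16))
print(model.modes, model.hidden_dim)  # 16 512
```

Keeping hyperparameters in a dataclass makes them easy to serialize and compare, while the classmethod keeps construction logic in one place.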
FNOEncoder
FNOEncoder(hidden_dim: int = 512, num_layers: int = 6, max_sequence_length: int = 1024, modes: int = 32, mlp_ratio: float = 2.0, ffn_hidden_dim: int | None = None, dropout: float = 0.0, use_positional_encoding: bool = True, positional_encoding_type: PositionalEncodingType = 'sinusoidal', gradient_checkpointing: bool = False)
Bases: BaseModel
Encoder-only FNO model for representation learning.
This model uses stacked FNO blocks without causal masking, suitable for bidirectional encoding tasks like feature extraction and representation learning.
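Because the FFT mixes all positions, an unmasked spectral layer is bidirectional: changing the last token can change the output at the first position. The NumPy sketch below checks this with a single-channel low-pass filter standing in for an FNO layer; this is a simplification (real FNO blocks also mix channels with learned weights), and `lowpass_mix` is a hypothetical name.

```python
import numpy as np


def lowpass_mix(x, modes):
    """Single-channel stand-in for an FNO layer: keep the lowest `modes` rFFT bins."""
    x_hat = np.fft.rfft(x)
    x_hat[modes:] = 0.0  # spectral truncation
    return np.fft.irfft(x_hat, n=len(x))


rng = np.random.default_rng(0)
x = rng.standard_normal(32)
x_perturbed = x.copy()
x_perturbed[-1] += 1.0  # change only the last position

delta = lowpass_mix(x_perturbed, modes=4) - lowpass_mix(x, modes=4)
print(abs(delta[0]) > 1e-6)  # True: the first output depends on the last input
```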
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `hidden_dim` | `int` | Hidden dimension size for the model. | `512` |
| `num_layers` | `int` | Number of encoder blocks. | `6` |
| `max_sequence_length` | `int` | Maximum sequence length. | `1024` |
| `modes` | `int` | Number of Fourier modes to retain. | `32` |
| `mlp_ratio` | `float` | MLP expansion ratio in FNO blocks. | `2.0` |
| `ffn_hidden_dim` | `int \| None` | Hidden dimension of the feedforward network. | `None` |
| `dropout` | `float` | Dropout probability. | `0.0` |
| `use_positional_encoding` | `bool` | Whether to use positional encoding. | `True` |
| `positional_encoding_type` | `PositionalEncodingType` | Type of positional encoding (`"sinusoidal"` or `"learned"`). | `"sinusoidal"` |
| `gradient_checkpointing` | `bool` | Whether to use gradient checkpointing. | `False` |
Examples:
>>> import torch
>>> from spectrans.models.fno_transformer import FNOEncoder
>>> encoder = FNOEncoder(
... hidden_dim=512,
... num_layers=6,
... modes=32,
... max_sequence_length=1024
... )
>>> x = torch.randn(32, 100, 512)
>>> encoded = encoder(inputs_embeds=x)
>>> assert encoded.shape == x.shape
Methods:
| Name | Description |
|---|---|
| `build_blocks` | Build encoder blocks with FNO layers. |
Functions
build_blocks
Build encoder blocks with FNO layers.
Returns:
| Type | Description |
|---|---|
| `ModuleList` | List of FNO encoder blocks. |
FNODecoder
FNODecoder(vocab_size: int, hidden_dim: int = 512, num_layers: int = 12, max_sequence_length: int = 2048, modes: int = 32, mlp_ratio: float = 2.0, causal: bool = True, ffn_hidden_dim: int | None = None, dropout: float = 0.0, use_positional_encoding: bool = True, positional_encoding_type: PositionalEncodingType = 'sinusoidal', gradient_checkpointing: bool = False)
Bases: BaseModel
Decoder FNO model for generation tasks.
This model uses causal FNO blocks suitable for autoregressive generation tasks. The spectral operations are modified to respect causality.
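This page does not say how spectrans makes its spectral operations causal. One standard technique, shown here as an assumption rather than the library's method, is to zero-pad to length \(2n\) so the FFT computes a linear rather than circular convolution, and to use a kernel defined only on non-negative lags; output position \(t\) then depends only on inputs \(0, \dots, t\). The name `causal_fft_conv` is hypothetical.

```python
import numpy as np


def causal_fft_conv(x, h):
    """Causal convolution y[t] = sum_{s<=t} h[t - s] * x[s], computed with FFTs.

    x, h: real arrays of length n (h is the kernel over lags 0..n-1).
    Padding to 2n prevents circular wrap-around from future positions.
    """
    n = len(x)
    x_hat = np.fft.rfft(x, n=2 * n)
    h_hat = np.fft.rfft(h, n=2 * n)
    y = np.fft.irfft(x_hat * h_hat, n=2 * n)
    return y[:n]  # keep the first n outputs


rng = np.random.default_rng(0)
x = rng.standard_normal(16)
h = rng.standard_normal(16)
y = causal_fft_conv(x, h)

# Matches direct linear convolution truncated to n outputs.
print(np.allclose(y, np.convolve(x, h)[:16]))  # True

# Perturbing the last input leaves all earlier outputs unchanged.
x2 = x.copy()
x2[-1] += 1.0
print(np.allclose(causal_fft_conv(x2, h)[:-1], y[:-1]))  # True
```

The cost stays \(O(n \log n)\); the price of causality is FFTs of twice the length.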
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `vocab_size` | `int` | Size of the vocabulary for generation. | required |
| `hidden_dim` | `int` | Hidden dimension size. | `512` |
| `num_layers` | `int` | Number of decoder blocks. | `12` |
| `max_sequence_length` | `int` | Maximum sequence length. | `2048` |
| `modes` | `int` | Number of Fourier modes (adjusted for causality). | `32` |
| `mlp_ratio` | `float` | MLP expansion ratio. | `2.0` |
| `causal` | `bool` | Whether to use causal masking. | `True` |
| `ffn_hidden_dim` | `int \| None` | Hidden dimension of the feedforward network. | `None` |
| `dropout` | `float` | Dropout probability. | `0.0` |
| `use_positional_encoding` | `bool` | Whether to use positional encoding. | `True` |
| `positional_encoding_type` | `PositionalEncodingType` | Type of positional encoding (`"sinusoidal"` or `"learned"`). | `"sinusoidal"` |
| `gradient_checkpointing` | `bool` | Whether to use gradient checkpointing. | `False` |
Examples:
>>> import torch
>>> from spectrans.models.fno_transformer import FNODecoder
>>> decoder = FNODecoder(
... vocab_size=10000,
... hidden_dim=512,
... num_layers=12,
... modes=32,
... causal=True,
... max_sequence_length=2048
... )
>>> input_ids = torch.randint(0, 10000, (32, 100))
>>> logits = decoder(input_ids)
>>> assert logits.shape == (32, 100, 10000)
Methods:
| Name | Description |
|---|---|
| `build_blocks` | Build decoder blocks with causal FNO layers. |
| `forward` | Forward pass through the decoder. |
Functions
build_blocks
Build decoder blocks with causal FNO layers.
Returns:
| Type | Description |
|---|---|
| `ModuleList` | List of causal FNO decoder blocks. |
forward
forward(input_ids: Tensor | None = None, inputs_embeds: Tensor | None = None, attention_mask: Tensor | None = None) -> Tensor
Forward pass through the decoder.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `input_ids` | `Tensor \| None` | Input token IDs of shape `(batch_size, sequence_length)`. | `None` |
| `inputs_embeds` | `Tensor \| None` | Pre-embedded inputs of shape `(batch_size, sequence_length, hidden_dim)`. | `None` |
| `attention_mask` | `Tensor \| None` | Attention mask for padding. | `None` |
|
Returns:
| Type | Description |
|---|---|
| `Tensor` | Logits of shape `(batch_size, sequence_length, vocab_size)`. |
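The NumPy sketch below illustrates one plausible way a decoder forward pass can accept either `input_ids` or `inputs_embeds` and honor a padding mask before spectral mixing. Everything here is an assumption not confirmed by this page: the helper name `prepare_inputs`, the rule that exactly one input must be given, and zeroing padded positions so they contribute nothing to the Fourier coefficients.

```python
import numpy as np


def prepare_inputs(embedding, input_ids=None, inputs_embeds=None, attention_mask=None):
    """Hypothetical input handling for a decoder forward pass.

    embedding: array of shape (vocab_size, hidden_dim) used as a lookup table.
    input_ids: int array (batch, seq_len); inputs_embeds: (batch, seq_len, hidden_dim).
    attention_mask: (batch, seq_len) with 1 for real tokens and 0 for padding.
    """
    if (input_ids is None) == (inputs_embeds is None):
        raise ValueError("pass exactly one of input_ids or inputs_embeds")
    x = embedding[input_ids] if input_ids is not None else inputs_embeds
    if attention_mask is not None:
        # Zero padded positions so they add nothing to the FFT of the sequence.
        x = x * attention_mask[..., None]
    return x


vocab_size, hidden_dim = 10, 4
emb = np.arange(vocab_size * hidden_dim, dtype=float).reshape(vocab_size, hidden_dim)
ids = np.array([[1, 2, 0]])
mask = np.array([[1, 1, 0]])
x = prepare_inputs(emb, input_ids=ids, attention_mask=mask)
print(x.shape)  # (1, 3, 4)
```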