FNet Models
spectrans.models.fnet
FNet: Mixing Tokens with Fourier Transforms.
This module implements the FNet architecture, which replaces the self-attention mechanism in transformers with Fourier transform-based token mixing. FNet's token mixing costs \(O(n \log n)\) in the sequence length \(n\), compared to \(O(n^2)\) for standard self-attention.
The architecture uses 2D Discrete Fourier Transforms (DFTs) to mix tokens, so every output position receives information from the whole sequence, at a lower computational cost than attention-based models.
Classes:
| Name | Description |
|---|---|
| FNet | Complete FNet model with Fourier mixing layers. |
| FNetEncoder | Encoder-only FNet for representation learning. |
Examples:
Basic FNet usage for classification:
>>> import torch
>>> from spectrans.models.fnet import FNet
>>> model = FNet(
... vocab_size=30000,
... hidden_dim=768,
... num_layers=12,
... max_sequence_length=512,
... num_classes=2
... )
>>> input_ids = torch.randint(0, 30000, (8, 512))
>>> logits = model(input_ids=input_ids)
>>> assert logits.shape == (8, 2)
Using FNet encoder for feature extraction:
>>> from spectrans.models.fnet import FNetEncoder
>>> encoder = FNetEncoder(hidden_dim=768, num_layers=12)
>>> inputs = torch.randn(8, 512, 768)
>>> features = encoder(inputs_embeds=inputs)
>>> assert features.shape == (8, 512, 768)
Creating from configuration:
>>> from spectrans.config.models import FNetModelConfig
>>> config = FNetModelConfig(
... hidden_dim=768,
... num_layers=12,
... sequence_length=512
... )
>>> model = FNet.from_config(config)
Notes
Mathematical Foundation
Given input tensor \(\mathbf{X} \in \mathbb{R}^{n \times d}\) where \(n\) is sequence length and \(d\) is hidden dimension, FNet applies the following operations in each layer \(l\):
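Fourier Mixing Operation:
The core mixing operation is defined as:
\[
\text{FourierMix}(\mathbf{X}) = \Re\left(\mathcal{F}_d^{-1}\left(\mathcal{F}_n(\mathbf{X})\right)\right)
\]
where: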
- \(\mathcal{F}_n\) denotes 1D DFT along the sequence dimension
- \(\mathcal{F}_d^{-1}\) denotes inverse 1D DFT along the feature dimension
- \(\Re(\cdot)\) takes the real part of complex values
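The mixing step can be sketched in NumPy for a single unbatched (n, d) input, following the definitions listed above: a DFT along the sequence axis, an inverse DFT along the feature axis, then the real part. This is an illustration rather than the spectrans implementation, and `fourier_mix` is a hypothetical name.

```python
import numpy as np


def fourier_mix(x: np.ndarray) -> np.ndarray:
    """Hypothetical helper computing Re(F_d^{-1}(F_n(X))) for a real (n, d) array."""
    # F_n: 1D DFT along the sequence axis (axis 0)
    seq_spectrum = np.fft.fft(x, axis=0)
    # F_d^{-1}: inverse 1D DFT along the feature axis (axis 1)
    mixed = np.fft.ifft(seq_spectrum, axis=1)
    # Re(.): keep the real part so the output is real-valued like the input
    return mixed.real


rng = np.random.default_rng(0)
x = rng.standard_normal((8, 4))  # n = 8 tokens, d = 4 features
y = fourier_mix(x)
print(y.shape)  # (8, 4)

# Global mixing: perturbing only token 0 changes the output at every position.
x_perturbed = x.copy()
x_perturbed[0] += 1.0
changed = ~np.isclose(fourier_mix(x_perturbed), y)
print(changed.any(axis=1).all())  # True
```

Every step is a fixed linear transform, which is why the mixing sublayer has no learned weights. For comparison, the original FNet paper applies forward DFTs along both axes, which in NumPy would be `np.fft.fft2(x).real`.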
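Complete Layer Operations:
For each FNet layer \(l = 1, \dots, L\), with \(\mathbf{X}^{(0)} = \mathbf{X}\), the computation proceeds as follows (post-norm residual layout, as in the original FNet paper):
\[
\mathbf{H}^{(l)} = \text{LayerNorm}\left(\mathbf{X}^{(l-1)} + \Re\left(\mathcal{F}_d^{-1}\left(\mathcal{F}_n(\mathbf{X}^{(l-1)})\right)\right)\right)
\]
\[
\mathbf{X}^{(l)} = \text{LayerNorm}\left(\mathbf{H}^{(l)} + \text{FFN}(\mathbf{H}^{(l)})\right)
\]
where the feedforward network (FFN) is applied independently to each position \(\mathbf{h} \in \mathbb{R}^d\):
\[
\text{FFN}(\mathbf{h}) = \mathbf{W}_2\,\sigma\left(\mathbf{W}_1 \mathbf{h} + \mathbf{b}_1\right) + \mathbf{b}_2
\]
with \(\mathbf{W}_1 \in \mathbb{R}^{4d \times d}\), \(\mathbf{b}_1 \in \mathbb{R}^{4d}\), \(\mathbf{W}_2 \in \mathbb{R}^{d \times 4d}\), \(\mathbf{b}_2 \in \mathbb{R}^d\), and \(\sigma\) an elementwise activation (GELU in the original FNet paper). The inner width \(4d\) corresponds to the default ffn_hidden_dim.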
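The layer equations can be sketched as a NumPy forward pass on one unbatched (n, d) input. The sketch assumes a post-norm residual layout (normalize after each residual add), a tanh-approximated GELU as the activation, layer normalization without a learned scale and shift, and small random FFN weights. These are illustrative choices, not a statement of what spectrans does, and all function names are hypothetical.

```python
import numpy as np


def fourier_mix(x):
    # Re(F_d^{-1}(F_n(X))): DFT over tokens (axis 0), inverse DFT over features (axis 1)
    return np.fft.ifft(np.fft.fft(x, axis=0), axis=1).real


def layer_norm(x, eps=1e-12):
    # Normalize each position over the feature axis (no learned scale/shift in this sketch)
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def gelu(x):
    # tanh approximation of GELU (assumed activation)
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))


def ffn(h, w1, b1, w2, b2):
    # Row-vector form of W2 sigma(W1 h + b1) + b2 applied to every position;
    # w1 has shape (4d, d) and w2 has shape (d, 4d), matching the text.
    return gelu(h @ w1.T + b1) @ w2.T + b2


def fnet_layer(x, params):
    # Mixing sublayer, then FFN sublayer, each followed by residual add + LayerNorm
    h = layer_norm(x + fourier_mix(x))
    return layer_norm(h + ffn(h, *params))


rng = np.random.default_rng(0)
n, d = 16, 8
params = (
    rng.standard_normal((4 * d, d)) * 0.02,  # W1
    np.zeros(4 * d),                         # b1
    rng.standard_normal((d, 4 * d)) * 0.02,  # W2
    np.zeros(d),                             # b2
)
x = rng.standard_normal((n, d))
out = fnet_layer(x, params)
print(out.shape)  # (16, 8)
```

Stacking \(L\) such layers, each with its own FFN parameters, gives the encoder body; only the FFN (and, in a full model, the normalization and embeddings) would carry trainable weights.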
Complexity Analysis:
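- Time Complexity: \(O(L \cdot n d \log(n d))\) for the Fourier mixing sublayers, where \(L\) is the number of layers; the FFN sublayers add \(O(L \cdot n d^2)\)
- Space Complexity: \(O(L \cdot n \cdot d)\)
- No learned parameters in the mixing operation (learned parameters live only in the FFNs, layer normalization, embeddings, and output head)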
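The cost difference can be made concrete: the mixing is a fixed linear map, so it can also be computed as two matrix products with DFT matrices, costing \(O(n^2 d + n d^2)\), while the FFT produces the same result in \(O(n d \log(n d))\). The NumPy sketch below (hypothetical helper names) checks that the two routes agree; the DFT matrices depend only on \(n\) and \(d\), never on trained weights.

```python
import numpy as np


def dft_matrix(m: int) -> np.ndarray:
    # Forward DFT matrix: F[j, t] = exp(-2*pi*i * j * t / m)
    idx = np.arange(m)
    return np.exp(-2j * np.pi * np.outer(idx, idx) / m)


def mix_fft(x: np.ndarray) -> np.ndarray:
    # FFT route, O(n d log(n d)): DFT over tokens, inverse DFT over features, real part
    return np.fft.ifft(np.fft.fft(x, axis=0), axis=1).real


def mix_dense(x: np.ndarray) -> np.ndarray:
    # Dense route, O(n^2 d + n d^2): the same linear map as explicit matrix products
    n, d = x.shape
    f_n = dft_matrix(n)                 # left-multiplying applies the DFT along the sequence axis
    f_d_inv = dft_matrix(d).conj() / d  # inverse DFT matrix for the feature axis
    # Inverse DFT along axis 1: Z[j, k] = sum_c Y[j, c] * f_d_inv[k, c], i.e. Y @ f_d_inv.T
    return (f_n @ x @ f_d_inv.T).real


rng = np.random.default_rng(0)
x = rng.standard_normal((6, 10))
print(np.allclose(mix_fft(x), mix_dense(x)))  # True
```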
References
James Lee-Thorp, Joshua Ainslie, Ilya Eckstein, and Santiago Ontanon. 2022. FNet: Mixing tokens with Fourier transforms. In Proceedings of the 2022 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies (NAACL-HLT), pages 4296-4313, Seattle.
See Also
spectrans.layers.mixing.fourier : Fourier mixing layer implementation.
spectrans.models.base : Base model classes.
Classes
FNet
FNet(vocab_size: int | None = None, hidden_dim: int = 768, num_layers: int = 12, max_sequence_length: int = 512, num_classes: int | None = None, use_positional_encoding: bool = True, positional_encoding_type: PositionalEncodingType = 'sinusoidal', dropout: float = 0.1, ffn_hidden_dim: int | None = None, norm_eps: float = 1e-12, output_type: OutputHeadType = 'classification', use_real_fft: bool = True, gradient_checkpointing: bool = False)
Bases: BaseModel
FNet model with Fourier transform-based token mixing.
FNet replaces the self-attention mechanism with Fourier transforms, reducing the cost of token mixing to \(O(n \log n)\) in the sequence length \(n\).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| vocab_size | int \| None | Size of the vocabulary for token embeddings. If None, expects pre-embedded inputs. | None |
| hidden_dim | int | Hidden dimension size. Default is 768. | 768 |
| num_layers | int | Number of FNet layers. Default is 12. | 12 |
| max_sequence_length | int | Maximum sequence length. Default is 512. | 512 |
| num_classes | int \| None | Number of output classes for classification. Default is None. | None |
| use_positional_encoding | bool | Whether to use positional encoding. Default is True. | True |
| positional_encoding_type | str | Type of positional encoding: 'sinusoidal' or 'learned'. Default is 'sinusoidal'. | 'sinusoidal' |
| dropout | float | Dropout probability. Default is 0.1. | 0.1 |
| ffn_hidden_dim | int \| None | Hidden dimension for FFN. If None, defaults to 4 * hidden_dim. | None |
| norm_eps | float | Epsilon for layer normalization. Default is 1e-12. | 1e-12 |
| output_type | str | Type of output head: 'classification', 'regression', 'sequence', or 'none'. Default is 'classification'. | 'classification' |
| use_real_fft | bool | Whether to use real FFT for efficiency. Default is True. | True |
| gradient_checkpointing | bool | Whether to use gradient checkpointing. Default is False. | False |
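For positional_encoding_type='sinusoidal', a common choice is the fixed sine/cosine encoding of the original Transformer. The NumPy sketch below assumes that formulation (sines on even feature indices, cosines on odd ones, base 10000) and an even hidden dimension; spectrans may lay the features out differently, and `sinusoidal_encoding` is a hypothetical name.

```python
import numpy as np


def sinusoidal_encoding(max_len: int, dim: int, base: float = 10000.0) -> np.ndarray:
    # PE[pos, 2i] = sin(pos / base^(2i/dim)), PE[pos, 2i+1] = cos(pos / base^(2i/dim))
    if dim % 2 != 0:
        raise ValueError("this sketch assumes an even dim")
    positions = np.arange(max_len)[:, None]       # shape (max_len, 1)
    freqs = base ** (-np.arange(0, dim, 2) / dim)  # shape (dim / 2,)
    angles = positions * freqs                     # shape (max_len, dim / 2)
    pe = np.empty((max_len, dim))
    pe[:, 0::2] = np.sin(angles)
    pe[:, 1::2] = np.cos(angles)
    return pe


pe = sinusoidal_encoding(max_len=512, dim=768)
print(pe.shape)  # (512, 768)
```

In the usual setup the first seq_len rows are added to the token embeddings before the first layer; the encoding is fixed, so it adds no trainable parameters, unlike a 'learned' table.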
Attributes:
| Name | Type | Description |
|---|---|---|
| use_real_fft | bool | Whether real FFT is used for efficiency. |
| blocks | ModuleList | List of FNet transformer blocks. |
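Real-input FFT routines such as NumPy's `rfft2` exploit the Hermitian symmetry of a real signal's spectrum, so only about half of the frequency bins need to be computed. The sketch below illustrates the idea on the forward-forward mixing of the original FNet paper, Re(fft2(X)): it computes the d // 2 + 1 non-redundant feature-axis bins and fills in the remaining columns by symmetry. This page does not describe how the use_real_fft path is actually implemented, so treat this as an assumption-level illustration; `real_fft2_mix` is hypothetical.

```python
import numpy as np


def real_fft2_mix(x: np.ndarray) -> np.ndarray:
    """Hypothetical helper: Re(fft2(x)) rebuilt from the half spectrum of rfft2."""
    n, d = x.shape
    half = np.fft.rfft2(x).real  # feature-axis bins 0 .. d // 2, shape (n, d // 2 + 1)
    out = np.empty((n, d))
    out[:, : d // 2 + 1] = half
    # Hermitian symmetry of a real input: X[j, k] = conj(X[-j mod n, -k mod d]),
    # so the real parts satisfy Re X[j, k] = Re X[-j mod n, d - k].
    ks = np.arange(d // 2 + 1, d)
    mirrored_rows = (-np.arange(n)) % n
    out[:, ks] = half[mirrored_rows][:, d - ks]
    return out


rng = np.random.default_rng(0)
x = rng.standard_normal((6, 7))  # odd d exercises the index arithmetic
print(np.allclose(real_fft2_mix(x), np.fft.fft2(x).real))  # True
```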
Methods:
| Name | Description |
|---|---|
| build_blocks | Build FNet transformer blocks with Fourier mixing. |
| from_config | Create FNet model from configuration. |
Source code in spectrans/models/fnet.py
Functions
build_blocks
Build FNet transformer blocks with Fourier mixing.
Returns:
| Type | Description |
|---|---|
| ModuleList | List of FNet transformer blocks. |
from_config (classmethod)
from_config(config: FNetModelConfig) -> FNet
Create FNet model from configuration.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | FNetModelConfig | Configuration object with model parameters. | required |
Returns:
| Type | Description |
|---|---|
| FNet | Configured FNet model. |
FNetEncoder
FNetEncoder(hidden_dim: int = 768, num_layers: int = 12, max_sequence_length: int = 512, use_positional_encoding: bool = True, positional_encoding_type: PositionalEncodingType = 'sinusoidal', dropout: float = 0.1, ffn_hidden_dim: int | None = None, norm_eps: float = 1e-12, use_real_fft: bool = True, gradient_checkpointing: bool = False)
Bases: FNet
Encoder-only FNet model for representation learning.
This variant of FNet is designed for tasks that require extracting representations rather than making predictions. It returns the hidden states from the final layer without any task-specific head.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| hidden_dim | int | Hidden dimension size. Default is 768. | 768 |
| num_layers | int | Number of FNet layers. Default is 12. | 12 |
| max_sequence_length | int | Maximum sequence length. Default is 512. | 512 |
| use_positional_encoding | bool | Whether to use positional encoding. Default is True. | True |
| positional_encoding_type | str | Type of positional encoding. Default is 'sinusoidal'. | 'sinusoidal' |
| dropout | float | Dropout probability. Default is 0.1. | 0.1 |
| ffn_hidden_dim | int \| None | Hidden dimension for FFN. If None, defaults to 4 * hidden_dim. | None |
| norm_eps | float | Epsilon for layer normalization. Default is 1e-12. | 1e-12 |
| use_real_fft | bool | Whether to use real FFT. Default is True. | True |
| gradient_checkpointing | bool | Whether to use gradient checkpointing. Default is False. | False |