GFNet Models¶
spectrans.models.gfnet ¶
Global Filter Networks (GFNet) for efficient spectral transformers.
This module implements the Global Filter Network architecture, which uses learnable complex-valued filters in the frequency domain for token mixing. GFNet provides a learnable alternative to FNet while maintaining \(O(n \log n)\) complexity.
Because the filters act in the Fourier domain, the model can selectively emphasize or suppress individual frequency components, while token mixing stays cheaper than self-attention, whose cost grows quadratically with sequence length.
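To make frequency-selective filtering concrete, the following minimal NumPy sketch applies a fixed, hand-written low-pass mask to a toy signal. It is an illustration only: GFNet learns its filters rather than fixing them, and the names `signal`, `mask`, and `filtered` are hypothetical and not part of spectrans.

```python
import numpy as np

n = 64
t = np.arange(n)
low = np.sin(2 * np.pi * 2 * t / n)    # 2 cycles over the sequence
high = np.sin(2 * np.pi * 20 * t / n)  # 20 cycles over the sequence
signal = low + high

# Integer frequency-bin indices in FFT order: 0, 1, ..., n/2 - 1, -n/2, ..., -1
bins = np.fft.fftfreq(n, d=1.0 / n)

# Hand-written low-pass mask: keep bins with |k| <= 5, zero out the rest.
mask = (np.abs(bins) <= 5).astype(float)

# Filter in the frequency domain, then transform back.
filtered = np.fft.ifft(mask * np.fft.fft(signal)).real
```

The mask removes the 20-cycle component and leaves the 2-cycle component untouched. A learned filter plays the same role, except that its per-frequency weights are trained instead of chosen by hand.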
Classes:
| Name | Description |
|---|---|
| `GFNet` | Complete GFNet model with global filter mixing layers. |
| `GFNetEncoder` | Encoder-only GFNet for representation learning. |
Examples:
Basic GFNet usage for classification:
>>> import torch
>>> from spectrans.models.gfnet import GFNet
>>> model = GFNet(
... vocab_size=30000,
... hidden_dim=768,
... num_layers=12,
... max_sequence_length=512,
... num_classes=2
... )
>>> input_ids = torch.randint(0, 30000, (8, 512))
>>> logits = model(input_ids=input_ids)
>>> assert logits.shape == (8, 2)
Using GFNet encoder:
>>> from spectrans.models.gfnet import GFNetEncoder
>>> encoder = GFNetEncoder(
... hidden_dim=768,
... num_layers=12,
... max_sequence_length=512
... )
>>> inputs = torch.randn(8, 512, 768)
>>> features = encoder(inputs_embeds=inputs)
>>> assert features.shape == (8, 512, 768)
Creating from configuration:
>>> from spectrans.config.models import GFNetModelConfig
>>> config = GFNetModelConfig(
... hidden_dim=768,
... num_layers=12,
... sequence_length=512
... )
>>> model = GFNet.from_config(config)
Notes
Mathematical Foundation
Given input tensor \(\mathbf{X} \in \mathbb{R}^{n \times d}\) where \(n\) is sequence length and \(d\) is hidden dimension, GFNet applies learnable complex filters in the frequency domain.
Global Filter Operation:
The core filtering operation is defined as:
\[\text{GF}(\mathbf{X}) = \mathcal{F}^{-1}\left[\mathbf{H} \odot \mathcal{F}[\mathbf{X}]\right]\]
where:
- \(\mathbf{H} \in \mathbb{C}^{n \times d}\) is a learnable complex-valued filter
- \(\odot\) denotes element-wise (Hadamard) multiplication
- \(\mathcal{F}\) and \(\mathcal{F}^{-1}\) are FFT and IFFT along sequence dimension
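The operation described above (FFT along the sequence dimension, element-wise multiplication by \(\mathbf{H}\), inverse FFT) can be sketched in NumPy as follows. This is a minimal sketch, not the spectrans implementation: the function name `global_filter` is hypothetical, and taking the real part of the result is an assumption made here so that the output stays real-valued.

```python
import numpy as np


def global_filter(x, h):
    """Apply IFFT(H * FFT(X)) along the sequence axis (axis 0).

    x: real array of shape (n, d); h: complex array of shape (n, d).
    """
    x_freq = np.fft.fft(x, axis=0)          # to the frequency domain
    y = np.fft.ifft(h * x_freq, axis=0)     # element-wise filter, then back
    return y.real                           # assumption: keep the real part


rng = np.random.default_rng(0)
n, d = 16, 4
x = rng.standard_normal((n, d))

# An all-ones filter leaves the input unchanged.
y_identity = global_filter(x, np.ones((n, d), dtype=complex))

# A filter that is the FFT of a kernel performs a circular convolution
# with that kernel, independently for each hidden channel.
kernel = rng.standard_normal((n, d))
y_conv = global_filter(x, np.fft.fft(kernel, axis=0))
direct = np.stack([
    sum(kernel[m] * x[(j - m) % n] for m in range(n))
    for j in range(n)
])
```

The second case is the convolution theorem at work: multiplying by an \(n \times d\) filter in the frequency domain is equivalent to a circular convolution whose kernel spans the whole sequence, which is why the filter is called global.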
Filter Parameterization:
The learnable filter \(\mathbf{H}\) is parameterized as:
\[\mathbf{H} = \sigma(\mathbf{W}_r) + i \cdot \sigma(\mathbf{W}_i)\]
where:
- \(\mathbf{W}_r, \mathbf{W}_i \in \mathbb{R}^{n \times d}\) are real-valued learnable parameters
- \(\sigma\) is an activation function (typically sigmoid or tanh)
- \(i\) is the imaginary unit
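Reading the definitions above as \(\mathbf{H} = \sigma(\mathbf{W}_r) + i\,\sigma(\mathbf{W}_i)\), the filter can be assembled from two real parameter arrays as in the sketch below. The helper `build_filter` is hypothetical and written in NumPy for illustration; the library's own parameter handling may differ.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def build_filter(w_r, w_i, activation="sigmoid"):
    """Combine two real (n, d) arrays into one complex (n, d) filter."""
    act = {"sigmoid": sigmoid, "tanh": np.tanh}[activation]
    return act(w_r) + 1j * act(w_i)


n, d = 8, 3
w_r = np.zeros((n, d))
w_i = np.zeros((n, d))

h_sigmoid = build_filter(w_r, w_i, "sigmoid")  # every entry is 0.5 + 0.5j
h_tanh = build_filter(w_r, w_i, "tanh")        # every entry is 0
```

The choice of activation bounds the filter: with sigmoid, the real and imaginary parts lie in \((0, 1)\); with tanh, they lie in \((-1, 1)\), so the filter can also flip the sign of a frequency component.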
Complete Layer Operations:
For each GFNet layer \(l\), the computation proceeds as:
where FFN follows the same structure as in FNet.
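As a rough illustration of how a mixing sublayer and an FFN sublayer compose, here is a NumPy sketch of a single layer. Several details are assumptions rather than facts about spectrans: the pre-norm residual arrangement, the GELU activation in the FFN, a layer norm without learnable scale and shift, and every function and variable name.

```python
import numpy as np


def layer_norm(x, eps=1e-12):
    # Normalize over the hidden dimension (no learnable scale/shift here).
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def gelu(x):
    # tanh approximation of GELU
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))


def gfnet_layer(x, h, w1, b1, w2, b2):
    """One hypothetical layer: global filter mixing, then a two-layer FFN."""
    # Mixing sublayer with residual connection (pre-norm is an assumption).
    mixed = np.fft.ifft(h * np.fft.fft(layer_norm(x), axis=0), axis=0).real
    x = x + mixed
    # Feed-forward sublayer with residual connection.
    hidden = gelu(layer_norm(x) @ w1 + b1)
    return x + hidden @ w2 + b2


rng = np.random.default_rng(0)
n, d, ffn_dim = 16, 8, 32            # ffn_dim = 4 * d, matching the documented default ratio
x = rng.standard_normal((n, d))
h = np.ones((n, d), dtype=complex)   # stand-in for a learned filter
w1 = 0.02 * rng.standard_normal((d, ffn_dim))
b1 = np.zeros(ffn_dim)
w2 = 0.02 * rng.standard_normal((ffn_dim, d))
b2 = np.zeros(d)

out = gfnet_layer(x, h, w1, b1, w2, b2)
```

Both sublayers preserve the \((n, d)\) shape, so layers of this form can be stacked freely.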
Complexity Analysis:
- Time Complexity: \(O(L \cdot n \log n \cdot d)\) where \(L\) is the number of layers
- Space Complexity: \(O(L \cdot n \cdot d)\)
- Learnable Parameters: \(2nd\) real values per layer for the complex filter (real and imaginary parts), i.e. \(O(nd)\)
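For a sense of scale, the short calculation below plugs the default sizes from this page (\(n = 512\), \(d = 768\), \(L = 12\)) into the expressions above. The mixing-cost figures ignore constant factors and are order-of-magnitude comparisons only; the variable names are illustrative.

```python
import math

n, d, num_layers = 512, 768, 12

# Complex filter: one real and one imaginary value per (position, channel).
filter_params_per_layer = 2 * n * d
total_filter_params = num_layers * filter_params_per_layer

# Rough per-channel mixing cost, constants ignored.
fft_cost = n * math.log2(n)   # O(n log n) for the FFT-based mixing
pairwise_cost = n * n         # O(n^2) for pairwise attention scores

print(filter_params_per_layer)  # 786432
print(total_filter_params)      # 9437184
print(fft_cost, pairwise_cost)  # 4608.0 262144
```

Note that the filter size depends on the sequence length \(n\), which is why the model takes a `max_sequence_length` argument.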
References
Yongming Rao, Wenliang Zhao, Zheng Zhu, Jiwen Lu, and Jie Zhou. 2021. Global filter networks for image classification. In Advances in Neural Information Processing Systems 34 (NeurIPS 2021), pages 980-993.
See Also
spectrans.layers.mixing.global_filter : Global filter mixing layer implementation.
spectrans.models.base : Base model classes.
Classes¶
GFNet ¶
GFNet(vocab_size: int | None = None, hidden_dim: int = 768, num_layers: int = 12, max_sequence_length: int = 512, num_classes: int | None = None, use_positional_encoding: bool = True, positional_encoding_type: PositionalEncodingType = 'sinusoidal', dropout: float = 0.1, ffn_hidden_dim: int | None = None, norm_eps: float = 1e-12, output_type: OutputHeadType = 'classification', filter_activation: FilterActivationType = 'sigmoid', gradient_checkpointing: bool = False)
Bases: BaseModel
Global Filter Network model with learnable frequency domain filters.
GFNet uses learnable complex filters in the Fourier domain for token mixing, maintaining \(O(n \log n)\) complexity.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `vocab_size` | `int \| None` | Size of the vocabulary for token embeddings. If None, expects pre-embedded inputs. | `None` |
| `hidden_dim` | `int` | Hidden dimension size. Default is 768. | `768` |
| `num_layers` | `int` | Number of GFNet layers. Default is 12. | `12` |
| `max_sequence_length` | `int` | Maximum sequence length. Default is 512. | `512` |
| `num_classes` | `int \| None` | Number of output classes for classification. Default is None. | `None` |
| `use_positional_encoding` | `bool` | Whether to use positional encoding. Default is True. | `True` |
| `positional_encoding_type` | `str` | Type of positional encoding: 'sinusoidal' or 'learned'. Default is 'sinusoidal'. | `'sinusoidal'` |
| `dropout` | `float` | Dropout probability. Default is 0.1. | `0.1` |
| `ffn_hidden_dim` | `int \| None` | Hidden dimension for FFN. If None, defaults to 4 * hidden_dim. | `None` |
| `norm_eps` | `float` | Epsilon for layer normalization. Default is 1e-12. | `1e-12` |
| `output_type` | `str` | Type of output head: 'classification', 'regression', 'sequence', or 'none'. Default is 'classification'. | `'classification'` |
| `filter_activation` | `str` | Activation function for filters: 'sigmoid' or 'tanh'. Default is 'sigmoid'. | `'sigmoid'` |
| `gradient_checkpointing` | `bool` | Whether to use gradient checkpointing. Default is False. | `False` |
Attributes:
| Name | Type | Description |
|---|---|---|
| `filter_activation` | `str` | Activation function used for filters. |
| `blocks` | `ModuleList` | List of GFNet transformer blocks. |
Methods:
| Name | Description |
|---|---|
| `build_blocks` | Build GFNet transformer blocks with global filter mixing. |
| `from_config` | Create GFNet model from configuration. |
Source code in spectrans/models/gfnet.py
Functions¶
build_blocks ¶
Build GFNet transformer blocks with global filter mixing.
Returns:
| Type | Description |
|---|---|
| `ModuleList` | List of GFNet transformer blocks. |
Source code in spectrans/models/gfnet.py
from_config classmethod ¶
from_config(config: GFNetModelConfig) -> GFNet
Create GFNet model from configuration.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `config` | `GFNetModelConfig` | Configuration object with model parameters. | required |
Returns:
| Type | Description |
|---|---|
| `GFNet` | Configured GFNet model. |
Source code in spectrans/models/gfnet.py
GFNetEncoder ¶
GFNetEncoder(hidden_dim: int = 768, num_layers: int = 12, max_sequence_length: int = 512, use_positional_encoding: bool = True, positional_encoding_type: PositionalEncodingType = 'sinusoidal', dropout: float = 0.1, ffn_hidden_dim: int | None = None, norm_eps: float = 1e-12, filter_activation: FilterActivationType = 'sigmoid', gradient_checkpointing: bool = False)
Bases: GFNet
Encoder-only GFNet model for representation learning.
This variant of GFNet is designed for tasks that require extracting representations rather than making predictions. It returns the hidden states from the final layer without any task-specific head.
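Since the encoder returns one hidden vector per position, downstream code usually has to reduce these to a fixed-size representation itself. The sketch below shows one common option, a length-aware mean over the sequence dimension, using a random NumPy array as a stand-in for the encoder output. The `lengths` array and all names are hypothetical; spectrans may or may not offer its own pooling utilities.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, seq_len, hidden_dim = 4, 10, 6

# Stand-in for encoder output of shape (batch, seq_len, hidden_dim).
features = rng.standard_normal((batch, seq_len, hidden_dim))

# Hypothetical number of real (non-padding) tokens in each sequence.
lengths = np.array([10, 7, 3, 1])

# Boolean mask of valid positions, shape (batch, seq_len).
mask = np.arange(seq_len)[None, :] < lengths[:, None]

# Average only over valid positions.
summed = (features * mask[:, :, None]).sum(axis=1)
pooled = summed / lengths[:, None]   # shape (batch, hidden_dim)
```

Other reductions, such as taking the first position or max-pooling, follow the same pattern; which one works best depends on the task.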
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `hidden_dim` | `int` | Hidden dimension size. Default is 768. | `768` |
| `num_layers` | `int` | Number of GFNet layers. Default is 12. | `12` |
| `max_sequence_length` | `int` | Maximum sequence length. Default is 512. | `512` |
| `use_positional_encoding` | `bool` | Whether to use positional encoding. Default is True. | `True` |
| `positional_encoding_type` | `str` | Type of positional encoding. Default is 'sinusoidal'. | `'sinusoidal'` |
| `dropout` | `float` | Dropout probability. Default is 0.1. | `0.1` |
| `ffn_hidden_dim` | `int \| None` | Hidden dimension for FFN. If None, defaults to 4 * hidden_dim. | `None` |
| `norm_eps` | `float` | Epsilon for layer normalization. Default is 1e-12. | `1e-12` |
| `filter_activation` | `str` | Activation function for filters. Default is 'sigmoid'. | `'sigmoid'` |
| `gradient_checkpointing` | `bool` | Whether to use gradient checkpointing. Default is False. | `False` |