Hybrid Models
spectrans.models.hybrid
Hybrid transformer models combining spectral and spatial mixing strategies.
This module implements transformer architectures that alternate between different mixing mechanisms across layers. Spectral mixing layers (Fourier, Wavelet, AFNO, GFNet) provide efficient global pattern capture, while spatial mixing layers (attention variants) model local dependencies. This design balances computational efficiency with modeling capacity.
Classes:
| Name | Description |
|---|---|
| `HybridTransformer` | Configurable hybrid model with alternating spectral-spatial layers. |
| `HybridEncoder` | Encoder-only variant for representation learning tasks. |
| `AlternatingTransformer` | Simplified model alternating between exactly two mixing types. |
Examples:
Basic hybrid transformer with Fourier-Attention alternation:
>>> import torch
>>> from spectrans.models.hybrid import HybridTransformer
>>> model = HybridTransformer(
...     vocab_size=30000,
...     hidden_dim=768,
...     num_layers=12,
...     spectral_type='fourier',
...     spatial_type='attention',
...     num_classes=10,
... )
>>> input_ids = torch.randint(0, 30000, (32, 512))
>>> logits = model(input_ids)
>>> assert logits.shape == (32, 10)
Wavelet-SpectralAttention hybrid:
>>> model = HybridTransformer(
...     hidden_dim=512,
...     num_layers=8,
...     spectral_type='wavelet',
...     spatial_type='spectral_attention',
...     spectral_config={'wavelet': 'db8', 'levels': 3},
...     spatial_config={'num_features': 256},
... )
Alternating transformer with two specific layer types:
>>> from spectrans.models.hybrid import AlternatingTransformer
>>> model = AlternatingTransformer(
...     hidden_dim=768,
...     num_layers=12,
...     layer1_type='fourier',
...     layer2_type='attention',
...     layer1_config={'use_real_fft': True},
...     layer2_config={'num_heads': 8},
... )
Notes
Mathematical Foundation:
Hybrid transformers alternate between spectral and spatial mixing strategies across layers. For the default "even_spectral" pattern with \(L\) layers:
Even-indexed layers (spectral mixing):
Odd-indexed layers (spatial mixing):
where each is followed by a feedforward network:
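The "even_spectral" pattern described above can be sketched as a small scheduling function. This is only an illustration: the names `even_spectral_schedule` and `describe_blocks` are hypothetical and not part of spectrans, and zero-based layer indices are assumed.

```python
def even_spectral_schedule(num_layers):
    """Return the mixing kind per layer for the 'even_spectral' pattern.

    Assumes zero-based indices: even layers are spectral, odd layers spatial.
    """
    if num_layers < 0:
        raise ValueError("num_layers must be non-negative")
    return ["spectral" if i % 2 == 0 else "spatial" for i in range(num_layers)]


def describe_blocks(num_layers, spectral_type="fourier", spatial_type="attention"):
    """Pair each layer index with its concrete mixer name and its FFN sublayer."""
    names = {"spectral": spectral_type, "spatial": spatial_type}
    return [
        (index, names[kind], "ffn")  # every block ends with a feedforward network
        for index, kind in enumerate(even_spectral_schedule(num_layers))
    ]


# Usage: a 4-layer model alternates fourier, attention, fourier, attention.
schedule = describe_blocks(4)
```

Keeping the schedule separate from block construction makes it easy to inspect which layers use which mixer before any model is built.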
Spectral Mixing Operations:
Different spectral mixing types provide varying computational complexities:
- Fourier: \(\text{FFT}\) and \(\text{IFFT}\) with \(O(n \log n)\) complexity
- Wavelet: Multi-scale DWT decomposition with \(O(n)\) complexity
- AFNO: Mode-truncated spectral convolution with \(O(k_n k_d)\) complexity
- GFNet: Learnable global filters with \(O(n \log n)\) complexity
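To make the idea of Fourier mixing concrete, here is a minimal pure-Python sketch in the style of FNet (Lee-Thorp et al., 2022): a DFT along the hidden dimension, a DFT along the sequence dimension, and then the real part. The helper names `dft` and `fourier_mix` are hypothetical, the naive DFT is \(O(n^2)\) rather than the \(O(n \log n)\) FFT, and whether the library's 'fourier' layer matches this formulation exactly is an assumption.

```python
import cmath


def dft(values):
    """Naive O(n^2) discrete Fourier transform of a 1-D sequence."""
    n = len(values)
    return [
        sum(v * cmath.exp(-2j * cmath.pi * k * t / n) for t, v in enumerate(values))
        for k in range(n)
    ]


def fourier_mix(x):
    """FNet-style token mixing on a list of n token vectors of d floats each."""
    n, d = len(x), len(x[0])
    # Transform each token vector along the hidden dimension.
    rows = [dft(token) for token in x]
    # Transform each hidden channel along the sequence dimension.
    cols = [dft([rows[i][j] for i in range(n)]) for j in range(d)]
    # Keep only the real part, indexed back as (sequence, hidden).
    return [[cols[j][i].real for j in range(d)] for i in range(n)]


# Usage: a single nonzero input entry influences every output position.
mixed = fourier_mix([[1.0, 0.0], [0.0, 0.0]])
```

The usage line illustrates why spectral mixing is described as global: every output entry depends on every input entry, with no learned parameters involved.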
Spatial Mixing Operations:
Spatial mixing layers model position-dependent interactions:
- Standard Attention: \(\text{softmax}(\frac{\mathbf{Q}\mathbf{K}^T}{\sqrt{d}})\mathbf{V}\) with \(O(n^2 d)\) complexity
- Spectral Attention: RFF-approximated attention with \(O(n D d)\) complexity, where \(D\) is the number of random Fourier features
- LST: Linear spectral transform attention with \(O(n \log n \cdot d)\) complexity
Complexity Analysis:
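For a hybrid model with \(L\) layers, sequence length \(n\), and hidden dimension \(d\), the total complexity depends on the dominant mixing operation. With \(L_s\) spectral and \(L_{sp}\) spatial layers (\(L = L_s + L_{sp}\)):
\[
T_{\text{total}} = L_s \,(T_{\text{spectral}} + T_{\text{FFN}}) + L_{sp} \,(T_{\text{spatial}} + T_{\text{FFN}})
\]
where \(T_{\text{spectral}}\) and \(T_{\text{spatial}}\) are the per-layer costs of the chosen mixing operations and \(T_{\text{FFN}} = O(n d^2)\) for feedforward networks.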
The hybrid approach reduces the overall complexity compared to pure attention models while maintaining modeling capacity through the complementary mixing strategies.
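A rough way to see this reduction is to add up order-of-magnitude operation counts per layer. The sketch below is not a benchmark: constants are dropped, the function names are hypothetical, the per-channel spectral orders are multiplied by \(d\) so that all entries are counted in the same units (an assumption made here), AFNO is left out because its cost depends on truncation parameters, and the spectral/spatial split assumes zero-based "even_spectral" alternation.

```python
import math


def mixing_cost(kind, n, d, num_features=256):
    """Order-of-magnitude operation count for one mixing layer (constants dropped)."""
    log_n = math.log2(n)
    costs = {
        "fourier": n * log_n * d,             # FFT along the sequence, per channel
        "gfnet": n * log_n * d,               # FFT plus elementwise learnable filter
        "wavelet": n * d,                     # fast DWT, per channel
        "attention": n * n * d,               # softmax(QK^T / sqrt(d)) V
        "spectral_attention": n * num_features * d,  # random-feature approximation
        "lst": n * log_n * d,
    }
    return costs[kind]


def ffn_cost(n, d):
    """Feedforward sublayer cost, O(n d^2)."""
    return n * d * d


def total_cost(num_layers, n, d, spectral, spatial):
    """Sum mixing and FFN costs over spectral (even) and spatial (odd) layers."""
    num_spectral = (num_layers + 1) // 2
    num_spatial = num_layers // 2
    ffn = ffn_cost(n, d)
    return (num_spectral * (mixing_cost(spectral, n, d) + ffn)
            + num_spatial * (mixing_cost(spatial, n, d) + ffn))


# Usage: compare a Fourier-attention hybrid against attention in every layer.
hybrid = total_cost(12, 512, 768, "fourier", "attention")
pure_attention = total_cost(12, 512, 768, "attention", "attention")
```

At these sizes the feedforward term is comparable to the attention term, so the saving from replacing half of the attention layers is real but bounded; it grows as \(n\) increases relative to \(d\).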
References
Yi Tay, Mostafa Dehghani, Dara Bahri, and Donald Metzler. 2022. Efficient transformers: A survey. ACM Computing Surveys, 55(6):1-28.
James Lee-Thorp, Joshua Ainslie, Ilya Eckstein, and Santiago Ontanon. 2022. FNet: Mixing tokens with Fourier transforms. In Proceedings of the 2022 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies (NAACL-HLT), pages 4296-4313, Seattle.
John Guibas, Morteza Mardani, Zongyi Li, Andrew Tao, Anima Anandkumar, and Bryan Catanzaro. 2022. Adaptive Fourier neural operators: Efficient token mixers for transformers. In Proceedings of the International Conference on Learning Representations (ICLR).
See Also
- spectrans.blocks.hybrid : Hybrid block implementations.
- spectrans.models.fnet : Pure Fourier transformer implementation.
- spectrans.models.spectral_attention : Pure spectral attention transformer.
Classes
StandardAttention
Bases: Module
Standard multi-head self-attention wrapper.
Wraps PyTorch's MultiheadAttention for use as a mixing layer.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `hidden_dim` | `int` | Hidden dimension size. | required |
| `num_heads` | `int` | Number of attention heads. | `8` |
| `dropout` | `float` | Dropout probability. | `0.0` |
Methods:
| Name | Description |
|---|---|
| `forward` | Apply self-attention. |
Source code in spectrans/models/hybrid.py
Functions
forward
Apply self-attention.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Tensor` | Input tensor of shape (batch_size, seq_len, hidden_dim). | required |
Returns:
| Type | Description |
|---|---|
| `Tensor` | Output tensor of the same shape as the input. |
HybridTransformer
HybridTransformer(vocab_size: int | None = None, hidden_dim: int = 768, num_layers: int = 12, max_sequence_length: int = 512, spectral_type: str = 'fourier', spatial_type: str = 'attention', alternation_pattern: str = 'even_spectral', num_heads: int = 8, spectral_config: dict | None = None, spatial_config: dict | None = None, num_classes: int | None = None, use_positional_encoding: bool = True, positional_encoding_type: PositionalEncodingType = 'sinusoidal', dropout: float = 0.1, ffn_hidden_dim: int | None = None, norm_eps: float = 1e-12, output_type: OutputHeadType = 'classification', gradient_checkpointing: bool = False)
Bases: BaseModel
Hybrid Spectral-Spatial Transformer model.
Combines spectral and spatial mixing strategies across layers to balance computational efficiency with modeling expressiveness. The model alternates between spectral layers (efficient global mixing) and spatial layers (expressive local modeling) according to configurable patterns.
For a sequence \(\mathbf{X}_0 \in \mathbb{R}^{n \times d}\), the hybrid transformer applies alternating transformations:
Spectral layers (\(\ell\) even for "even_spectral" pattern):
Spatial layers (\(\ell\) odd for "even_spectral" pattern):
where \(\text{LN}(\cdot)\) denotes LayerNorm and each block concludes with:
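One common way to write these updates is the pre-norm residual form below. The placement of \(\text{LN}\) inside the residual branch, the intermediate symbol \(\mathbf{Z}_\ell\), and the operator names SpectralMix, SpatialMix and FFN are assumptions made for illustration, not notation confirmed by the library.

```latex
\begin{aligned}
\mathbf{Z}_\ell &= \mathbf{X}_\ell + \text{SpectralMix}\big(\text{LN}(\mathbf{X}_\ell)\big), && \ell \text{ even} \\
\mathbf{Z}_\ell &= \mathbf{X}_\ell + \text{SpatialMix}\big(\text{LN}(\mathbf{X}_\ell)\big), && \ell \text{ odd} \\
\mathbf{X}_{\ell+1} &= \mathbf{Z}_\ell + \text{FFN}\big(\text{LN}(\mathbf{Z}_\ell)\big)
\end{aligned}
```

Here \(\ell = 0, \dots, L-1\), so the first layer applied to \(\mathbf{X}_0\) is spectral under the "even_spectral" pattern.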
The spectral mixing operations provide different complexity-accuracy tradeoffs:
- Fourier: \(O(n \log n)\) via FFT/IFFT
- Wavelet: \(O(n)\) via fast DWT algorithms
- AFNO: \(O(k_n k_d d)\) with mode truncation parameters \(k_n, k_d\)
- GFNet: \(O(n \log n)\) with learnable spectral filters
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `vocab_size` | `int \| None` | Size of the vocabulary for token embeddings. | `None` |
| `hidden_dim` | `int` | Hidden dimension size. | `768` |
| `num_layers` | `int` | Number of transformer blocks. | `12` |
| `max_sequence_length` | `int` | Maximum sequence length. | `512` |
| `spectral_type` | `str` | Type of spectral mixing: 'fourier', 'wavelet', 'afno', 'gfnet'. | `'fourier'` |
| `spatial_type` | `str` | Type of spatial mixing: 'attention', 'spectral_attention', 'lst'. | `'attention'` |
| `alternation_pattern` | `str` | How to alternate: 'even_spectral', 'alternate', 'custom'. | `'even_spectral'` |
| `num_heads` | `int` | Number of attention heads for spatial layers. | `8` |
| `spectral_config` | `dict \| None` | Additional configuration for spectral layers. | `None` |
| `spatial_config` | `dict \| None` | Additional configuration for spatial layers. | `None` |
| `num_classes` | `int \| None` | Number of output classes for classification. | `None` |
| `use_positional_encoding` | `bool` | Whether to use positional encoding. | `True` |
| `positional_encoding_type` | `PositionalEncodingType` | Type of positional encoding. | `'sinusoidal'` |
| `dropout` | `float` | Dropout probability. | `0.1` |
| `ffn_hidden_dim` | `int \| None` | Hidden dimension for FFN. | `None` |
| `norm_eps` | `float` | Layer normalization epsilon. | `1e-12` |
| `output_type` | `OutputHeadType` | Type of output head. | `'classification'` |
| `gradient_checkpointing` | `bool` | Whether to use gradient checkpointing. | `False` |
Attributes:
| Name | Type | Description |
|---|---|---|
| `spectral_type` | `str` | Type of spectral mixing being used. |
| `spatial_type` | `str` | Type of spatial mixing being used. |
| `alternation_pattern` | `str` | The alternation pattern. |
| `blocks` | `ModuleList` | List of hybrid transformer blocks. |
Methods:
| Name | Description |
|---|---|
| `build_blocks` | Build hybrid transformer blocks. |
| `from_config` | Create hybrid transformer from configuration. |
Functions
build_blocks
Build hybrid transformer blocks.
Returns:
| Type | Description |
|---|---|
| `ModuleList` | List of transformer blocks with alternating mixing strategies. |
from_config (classmethod)
from_config(config: HybridModelConfig) -> HybridTransformer
Create hybrid transformer from configuration.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `config` | `HybridModelConfig` | Configuration object with model parameters. | required |
Returns:
| Type | Description |
|---|---|
| `HybridTransformer` | Configured hybrid transformer model. |
HybridEncoder
HybridEncoder(hidden_dim: int = 768, num_layers: int = 12, max_sequence_length: int = 512, spectral_type: str = 'fourier', spatial_type: str = 'attention', alternation_pattern: str = 'even_spectral', num_heads: int = 8, spectral_config: dict | None = None, spatial_config: dict | None = None, use_positional_encoding: bool = True, positional_encoding_type: PositionalEncodingType = 'sinusoidal', dropout: float = 0.1, ffn_hidden_dim: int | None = None, norm_eps: float = 1e-12, gradient_checkpointing: bool = False)
Bases: HybridTransformer
Encoder-only hybrid transformer for representation learning.
This variant returns hidden states without any task-specific head, suitable for feature extraction and representation learning.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `hidden_dim` | `int` | Hidden dimension size. | `768` |
| `num_layers` | `int` | Number of transformer blocks. | `12` |
| `max_sequence_length` | `int` | Maximum sequence length. | `512` |
| `spectral_type` | `str` | Type of spectral mixing. | `'fourier'` |
| `spatial_type` | `str` | Type of spatial mixing. | `'attention'` |
| `alternation_pattern` | `str` | Layer alternation pattern. | `'even_spectral'` |
| `num_heads` | `int` | Number of attention heads. | `8` |
| `spectral_config` | `dict \| None` | Spectral layer configuration. | `None` |
| `spatial_config` | `dict \| None` | Spatial layer configuration. | `None` |
| `use_positional_encoding` | `bool` | Whether to use positional encoding. | `True` |
| `positional_encoding_type` | `PositionalEncodingType` | Type of positional encoding. | `'sinusoidal'` |
| `dropout` | `float` | Dropout probability. | `0.1` |
| `ffn_hidden_dim` | `int \| None` | Hidden dimension for FFN. | `None` |
| `norm_eps` | `float` | Layer normalization epsilon. | `1e-12` |
| `gradient_checkpointing` | `bool` | Whether to use gradient checkpointing. | `False` |
AlternatingTransformer
AlternatingTransformer(vocab_size: int | None = None, hidden_dim: int = 768, num_layers: int = 12, max_sequence_length: int = 512, layer1_type: str = 'fourier', layer2_type: str = 'attention', layer1_config: dict | None = None, layer2_config: dict | None = None, num_classes: int | None = None, use_positional_encoding: bool = True, positional_encoding_type: PositionalEncodingType = 'sinusoidal', dropout: float = 0.1, ffn_hidden_dim: int | None = None, norm_eps: float = 1e-12, output_type: OutputHeadType = 'classification', gradient_checkpointing: bool = False)
Bases: BaseModel
Transformer that strictly alternates between two mixing strategies.
A simplified hybrid model that alternates between exactly two types of mixing layers following a strict pattern: layer1_type for even-indexed layers, layer2_type for odd-indexed layers. This design enables controlled comparisons between different mixing strategies.
For \(L\) layers, the alternation follows:
Each layer applies the mixing operation with residual connection:
followed by the standard feedforward block with another residual connection.
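The alternation rule stated above (layer1_type on even-indexed layers, layer2_type on odd-indexed layers) can be written compactly as follows. The symbols \(\text{Mix}^{(1)}\), \(\text{Mix}^{(2)}\) and \(\mathbf{Z}_\ell\) are notation introduced here for illustration, zero-based indices are assumed, and any normalization inside the residual branch is omitted because the text does not specify it.

```latex
\text{Mix}_\ell =
\begin{cases}
\text{Mix}^{(1)} & \text{if } \ell \bmod 2 = 0 \quad (\texttt{layer1\_type}) \\
\text{Mix}^{(2)} & \text{if } \ell \bmod 2 = 1 \quad (\texttt{layer2\_type})
\end{cases}
\qquad \ell = 0, \dots, L-1

\mathbf{Z}_\ell = \mathbf{X}_\ell + \text{Mix}_\ell(\mathbf{X}_\ell), \qquad
\mathbf{X}_{\ell+1} = \mathbf{Z}_\ell + \text{FFN}(\mathbf{Z}_\ell)
```

With the defaults, this gives Fourier mixing in layers 0, 2, 4, ... and attention in layers 1, 3, 5, ....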
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `vocab_size` | `int \| None` | Size of the vocabulary for token embeddings. | `None` |
| `hidden_dim` | `int` | Hidden dimension size. | `768` |
| `num_layers` | `int` | Number of transformer blocks. | `12` |
| `max_sequence_length` | `int` | Maximum sequence length. | `512` |
| `layer1_type` | `str` | Type of first mixing layer. | `'fourier'` |
| `layer2_type` | `str` | Type of second mixing layer. | `'attention'` |
| `layer1_config` | `dict \| None` | Configuration for first layer type. | `None` |
| `layer2_config` | `dict \| None` | Configuration for second layer type. | `None` |
| `num_classes` | `int \| None` | Number of output classes. | `None` |
| `use_positional_encoding` | `bool` | Whether to use positional encoding. | `True` |
| `positional_encoding_type` | `PositionalEncodingType` | Type of positional encoding. | `'sinusoidal'` |
| `dropout` | `float` | Dropout probability. | `0.1` |
| `ffn_hidden_dim` | `int \| None` | Hidden dimension for FFN. | `None` |
| `norm_eps` | `float` | Layer normalization epsilon. | `1e-12` |
| `output_type` | `OutputHeadType` | Type of output head. | `'classification'` |
| `gradient_checkpointing` | `bool` | Whether to use gradient checkpointing. | `False` |
Methods:
| Name | Description |
|---|---|
| `build_blocks` | Build alternating transformer blocks. |
Functions
build_blocks
Build alternating transformer blocks.
Returns:
| Type | Description |
|---|---|
| `ModuleList` | List of alternating transformer blocks. |