Layer Implementations
spectrans.layers
Layer implementations for spectral transformers.
Provides spectral transformer layers that replace traditional attention mechanisms with spectral operations while remaining compatible with standard transformer architectures. The layers fall into three categories suited to different use cases: mixing layers, attention layers, and neural operators.
Modules:
| Name | Description |
|---|---|
| attention | Spectral attention mechanisms with linear complexity. |
| mixing | Token mixing layers using spectral transforms. |
| operators | Fourier neural operators for function space learning. |
Classes:
| Name | Description |
|---|---|
| AdaptiveGlobalFilter | Enhanced global filter with adaptive initialization. |
| AFNOMixing | Adaptive Fourier Neural Operator with mode truncation. |
| DCTAttention | Specialized LST attention using discrete cosine transform. |
| FilterMixingLayer | Base class for learnable frequency domain filters. |
| FNOBlock | FNO block with spectral convolution and feedforward. |
| FourierMixing | 2D FFT mixing for both sequence and feature dimensions (FNet). |
| FourierMixing1D | 1D FFT mixing along sequence dimension only. |
| FourierNeuralOperator | Base FNO layer for learning operators in function spaces. |
| GlobalFilterMixing | Learnable complex filters in frequency domain (GFNet). |
| GlobalFilterMixing2D | 2D variant with filtering in both dimensions. |
| HadamardAttention | Fast attention using Hadamard transform operations. |
| KernelAttention | General kernel-based attention with various kernel options. |
| LSTAttention | Linear Spectral Transform attention with configurable transforms. |
| MixedSpectralAttention | Multi-transform attention combining multiple spectral methods. |
| MixingLayer | Base class for spectral mixing operations. |
| PerformerAttention | Performer-style attention with FAVOR+ algorithm. |
| RealFourierMixing | Memory-efficient real FFT variant for real-valued inputs. |
| SeparableFourierMixing | Configurable sequence and/or feature mixing. |
| SpectralAttention | Multi-head spectral attention using random Fourier features. |
| SpectralConv1d | 1D spectral convolution operator for sequence data. |
| SpectralConv2d | 2D spectral convolution operator for image-like data. |
| UnitaryMixingLayer | Base class for energy-preserving mixing transforms. |
| WaveletMixing | 1D wavelet mixing using discrete wavelet transform. |
| WaveletMixing2D | 2D wavelet mixing for spatial data processing. |
Examples:
Basic Fourier mixing layer (FNet-style):
>>> import torch
>>> from spectrans.layers import FourierMixing
>>>
>>> # Create Fourier mixing layer
>>> mixer = FourierMixing(hidden_dim=768)
>>> x = torch.randn(32, 512, 768) # (batch, sequence, hidden)
>>> output = mixer(x)
>>> assert output.shape == x.shape
Global filter mixing with learnable parameters:
>>> from spectrans.layers import GlobalFilterMixing
>>>
>>> # Create global filter with learnable complex weights
>>> filter_layer = GlobalFilterMixing(
... hidden_dim=512,
... sequence_length=1024,
... activation='sigmoid'
... )
>>> x = torch.randn(16, 1024, 512)
>>> output = filter_layer(x)
Spectral attention with random Fourier features:
>>> from spectrans.layers import SpectralAttention
>>>
>>> # Create spectral attention layer
>>> attention = SpectralAttention(
... hidden_dim=768,
... num_heads=12,
... num_features=256
... )
>>> x = torch.randn(8, 256, 768)
>>> output = attention(x)
Notes
Layer Categories and Complexity:
Mixing layers have \(O(n \log n)\) or \(O(n)\) complexity. Parameter-free variants apply FFTs directly, learnable variants such as global filters and AFNO add trainable frequency-domain weights, and multiresolution variants use wavelet transforms for hierarchical processing.
Attention layers reach linear or near-linear complexity: kernel approximations with Random Fourier Features and orthogonal features run in \(O(n)\), transform-based methods using DCT, DST, and Hadamard transforms run in \(O(n \log n)\) with fast transforms, and hybrid approaches combine multiple transforms with learnable mixing.
Neural operators have \(O(k \cdot d^2 + n \log n)\) complexity, where \(k\) is the number of retained Fourier modes, \(d\) is the channel dimension, and \(n\) is the sequence length. They learn mappings between infinite-dimensional function spaces; because the weights are parameterized in the Fourier domain rather than on a grid, the learned operator is independent of the discretization, which makes it resolution-invariant.
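To make the two cost terms concrete, here is a minimal NumPy sketch of an FNO-style 1D spectral convolution. It assumes a single unbatched input of shape (n, d_in) and one complex d_in by d_out weight matrix per retained mode; `spectral_conv1d` and this weight layout are hypothetical and are not the `SpectralConv1d` API.

```python
import numpy as np


def spectral_conv1d(x: np.ndarray, weights: np.ndarray) -> np.ndarray:
    """Hypothetical FNO-style spectral convolution (illustration only).

    x:       real array of shape (n, d_in)
    weights: complex array of shape (k, d_in, d_out), one matrix per kept mode
    """
    n = x.shape[0]
    k, _, d_out = weights.shape
    x_hat = np.fft.rfft(x, axis=0)  # (n // 2 + 1, d_in), O(n log n)
    out_hat = np.zeros((x_hat.shape[0], d_out), dtype=complex)
    # Channel mixing for each of the k lowest modes only: O(k * d^2)
    out_hat[:k] = np.einsum("mi,mio->mo", x_hat[:k], weights)
    return np.fft.irfft(out_hat, n=n, axis=0)  # back to (n, d_out), O(n log n)


rng = np.random.default_rng(0)
x = rng.standard_normal((64, 8))
w = rng.standard_normal((12, 8, 16)) + 1j * rng.standard_normal((12, 8, 16))
y = spectral_conv1d(x, w)
print(y.shape)  # (64, 16)

# The same weights apply unchanged on a finer grid (128 points instead of 64)
y_fine = spectral_conv1d(rng.standard_normal((128, 8)), w)
print(y_fine.shape)  # (128, 16)
```

Because the weights are indexed by mode rather than by grid point, nothing in the sketch depends on n, which is the sense in which such an operator is discretization-independent.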
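The Fourier-based layers rely on the convolution theorem for global mixing:
\[ x * h = \mathcal{F}^{-1}\big(\mathcal{F}(x) \odot \mathcal{F}(h)\big), \]
where \(*\) denotes circular convolution along the sequence, \(\mathcal{F}\) is the discrete Fourier transform, and \(\odot\) is the elementwise product. This replaces quadratic \(O(n^2)\) attention with log-linear \(O(n \log n)\) or linear \(O(n)\) spectral operations.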
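The convolution theorem can be checked numerically. A minimal NumPy sketch, independent of spectrans, comparing an \(O(n^2)\) direct circular convolution with the \(O(n \log n)\) FFT route:

```python
import numpy as np


def circular_conv_direct(x: np.ndarray, h: np.ndarray) -> np.ndarray:
    """O(n^2) circular convolution: y[t] = sum_s x[s] * h[(t - s) mod n]."""
    n = len(x)
    return np.array([sum(x[s] * h[(t - s) % n] for s in range(n)) for t in range(n)])


def circular_conv_fft(x: np.ndarray, h: np.ndarray) -> np.ndarray:
    """O(n log n) circular convolution: multiply spectra, then invert."""
    return np.fft.ifft(np.fft.fft(x) * np.fft.fft(h)).real


rng = np.random.default_rng(0)
x = rng.standard_normal(16)
h = rng.standard_normal(16)
print(np.allclose(circular_conv_direct(x, h), circular_conv_fft(x, h)))  # True
```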
References
James Lee-Thorp, Joshua Ainslie, Ilya Eckstein, and Santiago Ontanon. 2022. FNet: Mixing tokens with Fourier transforms. In Proceedings of the 2022 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies (NAACL-HLT), pages 4296-4313, Seattle.
Yongming Rao, Wenliang Zhao, Zheng Zhu, Jiwen Lu, and Jie Zhou. 2021. Global filter networks for image classification. In Advances in Neural Information Processing Systems 34 (NeurIPS 2021), pages 980-993.
John Guibas, Morteza Mardani, Zongyi Li, Andrew Tao, Anima Anandkumar, and Bryan Catanzaro. 2022. Adaptive Fourier neural operators: Efficient token mixers for transformers. In Proceedings of the International Conference on Learning Representations (ICLR).
Zongyi Li, Nikola Kovachki, Kamyar Azizzadenesheli, Burigede Liu, Kaushik Bhattacharya, Andrew Stuart, and Anima Anandkumar. 2021. Fourier neural operator for parametric partial differential equations. In Proceedings of the International Conference on Learning Representations (ICLR).
Krzysztof Choromanski, Valerii Likhosherstov, David Dohan, Xingyou Song, Andreea Gane, Tamas Sarlos, Peter Hawkins, Jared Davis, Afroz Mohiuddin, Lukasz Kaiser, David Belanger, Lucy Colwell, and Adrian Weller. 2021. Rethinking attention with performers. In Proceedings of the International Conference on Learning Representations (ICLR).
See Also
spectrans.transforms : Underlying spectral transform implementations.
spectrans.models : Model implementations using these layers.
spectrans.blocks : Transformer blocks that compose these layers.
Classes
DCTAttention
DCTAttention(hidden_dim: int, num_heads: int = 8, dct_type: int = 2, learnable_scale: bool = True, dropout: float = 0.0)
Bases: LSTAttention
Attention using Discrete Cosine Transform.
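Specialized LST attention that applies the DCT in every head. The DCT is real-valued, which suits real-valued signals, and it compacts most of a smooth signal's energy into a few low-frequency coefficients.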
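A minimal NumPy sketch of the two DCT properties this class relies on, independent of spectrans (`dct2_matrix` is a hypothetical helper): the orthonormal DCT-II preserves energy, and it packs most of a smooth signal's energy into the lowest-frequency coefficients.

```python
import numpy as np


def dct2_matrix(n: int) -> np.ndarray:
    """Orthonormal DCT-II matrix C, so that C @ C.T is the identity."""
    k = np.arange(n)[:, None]  # frequency index (rows)
    t = np.arange(n)[None, :]  # sample index (columns)
    c = np.sqrt(2.0 / n) * np.cos(np.pi * (2 * t + 1) * k / (2 * n))
    c[0] /= np.sqrt(2.0)  # rescale the DC row for orthonormality
    return c


n = 32
C = dct2_matrix(n)
x = np.linspace(0.0, 1.0, n) ** 2  # smooth, real-valued signal
X = C @ x

print(np.allclose(C @ C.T, np.eye(n)))         # True: orthogonal transform
print(np.isclose(np.sum(x**2), np.sum(X**2)))  # True: energy preserved
print(np.sum(X[:4] ** 2) / np.sum(X**2))       # share of energy in the 4 lowest coefficients
```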
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| hidden_dim | int | Hidden dimension. | required |
| num_heads | int | Number of attention heads. | 8 |
| dct_type | int | DCT type (2 is most common). | 2 |
| learnable_scale | bool | Whether to use learnable scaling. | True |
| dropout | float | Dropout probability. | 0.0 |
Source code in spectrans/layers/attention/lst.py
HadamardAttention
HadamardAttention(hidden_dim: int, num_heads: int = 8, scale_by_sqrt: bool = True, learnable_scale: bool = True, dropout: float = 0.0)
Bases: LSTAttention
Attention using fast Hadamard transform.
Uses the fast Hadamard transform, whose coefficients are all \(\pm 1\), to compute attention in \(O(n \log n)\) time.
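The \(O(n \log n)\) cost comes from the fast Walsh-Hadamard butterfly. A minimal NumPy sketch, not the spectrans implementation (`fwht` and `hadamard_matrix` are hypothetical helpers); the final \(1/\sqrt{n}\) factor corresponds to what the scale_by_sqrt parameter is documented to do.

```python
import numpy as np


def fwht(x: np.ndarray) -> np.ndarray:
    """Orthonormal fast Walsh-Hadamard transform of a 1D array whose length is a power of 2.

    Runs log2(n) butterfly stages of n/2 add/subtract pairs, so the cost is O(n log n).
    """
    x = np.array(x, dtype=float)  # work on a float copy
    n = x.shape[0]
    if n == 0 or n & (n - 1):
        raise ValueError("length must be a power of 2")
    h = 1
    while h < n:
        for i in range(0, n, 2 * h):
            a = x[i:i + h].copy()
            b = x[i + h:i + 2 * h].copy()
            x[i:i + h] = a + b
            x[i + h:i + 2 * h] = a - b
        h *= 2
    return x / np.sqrt(n)  # 1/sqrt(n) scaling makes the transform orthonormal


def hadamard_matrix(n: int) -> np.ndarray:
    """Sylvester construction of the n x n Hadamard matrix (every entry is +1 or -1)."""
    H = np.array([[1.0]])
    while H.shape[0] < n:
        H = np.block([[H, H], [H, -H]])
    return H


n = 8
x = np.arange(n, dtype=float)
H = hadamard_matrix(n) / np.sqrt(n)
print(np.allclose(fwht(x), H @ x))    # True: matches the dense O(n^2) product
print(np.allclose(fwht(fwht(x)), x))  # True: symmetric and orthonormal, hence self-inverse
```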
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| hidden_dim | int | Hidden dimension. | required |
| num_heads | int | Number of attention heads. | 8 |
| scale_by_sqrt | bool | Whether to scale by sqrt(n) for orthogonality. | True |
| learnable_scale | bool | Whether to use learnable diagonal scaling. | True |
| dropout | float | Dropout probability. | 0.0 |
Source code in spectrans/layers/attention/lst.py
KernelAttention
KernelAttention(hidden_dim: int, num_heads: int = 8, kernel_type: Literal['gaussian', 'polynomial', 'spectral'] = 'gaussian', rank: int | None = None, num_features: int | None = None, dropout: float = 0.0)
Bases: AttentionLayer
General kernel-based attention with various kernel options.
Supports multiple kernel types including Gaussian, polynomial, and learnable spectral kernels.
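Kernel-based attention gets its linear cost from reassociating the attention product: with a feature map phi, phi(Q) phi(K)^T V can be computed as phi(Q) (phi(K)^T V) without forming the n by n matrix. A minimal NumPy sketch, using the elu(x) + 1 feature map from linear-transformer work purely as a stand-in (it is not one of the documented kernel_type options, and these helpers are hypothetical):

```python
import numpy as np


def feature_map(x: np.ndarray) -> np.ndarray:
    """Positive feature map elu(x) + 1 (a stand-in; any positive map works here)."""
    return np.where(x > 0, x + 1.0, np.exp(x))


def quadratic_kernel_attention(q, k, v):
    """O(n^2 d): materialize the n x n kernel matrix A = phi(Q) phi(K)^T."""
    a = feature_map(q) @ feature_map(k).T
    return (a @ v) / a.sum(axis=1, keepdims=True)


def linear_kernel_attention(q, k, v):
    """O(n d^2): compute phi(Q) (phi(K)^T V) without the n x n matrix."""
    phi_q, phi_k = feature_map(q), feature_map(k)
    kv = phi_k.T @ v                        # (d, d_v) summary of keys and values
    normalizer = phi_q @ phi_k.sum(axis=0)  # (n,) row sums of A
    return (phi_q @ kv) / normalizer[:, None]


rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((128, 16)) for _ in range(3))
print(np.allclose(quadratic_kernel_attention(q, k, v), linear_kernel_attention(q, k, v)))  # True
```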
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| hidden_dim | int | Hidden dimension. | required |
| num_heads | int | Number of heads. | 8 |
| kernel_type | Literal['gaussian', 'polynomial', 'spectral'] | Type of kernel to use. | "gaussian" |
| rank | int \| None | Rank for low-rank approximations. | None |
| num_features | int \| None | Number of features for RFF kernels. | None |
| dropout | float | Dropout probability. | 0.0 |
Attributes:
| Name | Type | Description |
|---|---|---|
| kernel_type | str | Type of kernel being used. |
| rank | int \| None | Rank for approximations. |
Methods:
| Name | Description |
|---|---|
| forward | Forward pass of kernel attention. |
Source code in spectrans/layers/attention/spectral.py
Functions
forward
forward(x: Tensor, mask: Tensor | None = None, return_attention: bool = False) -> Tensor | tuple[Tensor, ...]
Forward pass of kernel attention.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Input of shape (batch_size, seq_len, hidden_dim). | required |
| mask | Tensor \| None | Attention mask. | None |
| return_attention | bool | Whether to return attention weights. | False |
Returns:
| Type | Description |
|---|---|
| Tensor or tuple[Tensor, Tensor] | Output and optionally attention weights. |
Source code in spectrans/layers/attention/spectral.py
LSTAttention
LSTAttention(hidden_dim: int, num_heads: int = 8, transform_type: Literal['dct', 'dst', 'hadamard', 'mixed'] = 'dct', learnable_scale: bool = True, normalize: bool = True, dropout: float = 0.0, use_bias: bool = True)
Bases: AttentionLayer
Linear Spectral Transform attention mechanism.
Implements attention using orthogonal transforms (DCT, DST, Hadamard) with learnable diagonal scaling.
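One way to read "orthogonal transform with learnable diagonal scaling" is: transform each head's values along the sequence axis, scale every coefficient, and transform back. The NumPy sketch below shows that reading for a single head; `lst_mix` is hypothetical, the scale is a plain array rather than a learned parameter, and the dense DCT matrix (an \(O(n^2)\) product) stands in for a fast DCT, which would bring the cost down to \(O(n \log n)\).

```python
import numpy as np


def dct2_matrix(n: int) -> np.ndarray:
    """Orthonormal DCT-II matrix (rows are frequencies)."""
    k = np.arange(n)[:, None]
    t = np.arange(n)[None, :]
    c = np.sqrt(2.0 / n) * np.cos(np.pi * (2 * t + 1) * k / (2 * n))
    c[0] /= np.sqrt(2.0)
    return c


def lst_mix(x: np.ndarray, scale: np.ndarray) -> np.ndarray:
    """Hypothetical LST-style mixing along the sequence axis for one head.

    x:     (seq_len, head_dim) values
    scale: (seq_len,) diagonal weights, one per transform coefficient
    """
    c = dct2_matrix(x.shape[0])
    coeffs = c @ x                    # forward orthogonal transform over tokens
    coeffs = coeffs * scale[:, None]  # diagonal scaling in the transform domain
    return c.T @ coeffs               # inverse transform (C^{-1} = C^T)


rng = np.random.default_rng(0)
x = rng.standard_normal((16, 4))
print(np.allclose(lst_mix(x, np.ones(16)), x))  # True: unit scaling is the identity
```

Keeping only the first coefficient (scale = [1, 0, ..., 0]) replaces every token by the sequence mean, which shows how the diagonal weights act as a frequency-selective mixer.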
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| hidden_dim | int | Hidden dimension of the model. | required |
| num_heads | int | Number of attention heads. | 8 |
| transform_type | Literal['dct', 'dst', 'hadamard', 'mixed'] | Type of transform to use. "mixed" uses different transforms per head. | "dct" |
| learnable_scale | bool | Whether to use learnable diagonal scaling matrix. | True |
| normalize | bool | Whether to normalize in transform domain. | True |
| dropout | float | Dropout probability. | 0.0 |
| use_bias | bool | Whether to use bias in projections. | True |
Attributes:
| Name | Type | Description |
|---|---|---|
| head_dim | int | Dimension per attention head. |
| transform_type | str | Type of transform being used. |
| transforms | ModuleList | List of transforms (one per head if mixed). |
| scale | Parameter \| None | Learnable diagonal scaling if enabled. |
Methods:
| Name | Description |
|---|---|
| forward | Forward pass of LST attention. |
Source code in spectrans/layers/attention/lst.py
Functions
forward
forward(x: Tensor, mask: Tensor | None = None, return_attention: bool = False) -> Tensor | tuple[Tensor, ...]
Forward pass of LST attention.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Input tensor of shape (batch_size, seq_len, hidden_dim). | required |
| mask | Tensor \| None | Attention mask of shape (batch_size, seq_len). | None |
| return_attention | bool | Whether to return attention weights (not supported). | False |
Returns:
| Type | Description |
|---|---|
| Tensor or tuple[Tensor, Tensor] | Output tensor of shape (batch_size, seq_len, hidden_dim). If return_attention=True, returns (output, None). |
Source code in spectrans/layers/attention/lst.py
MixedSpectralAttention
MixedSpectralAttention(hidden_dim: int, num_heads: int = 8, use_fft: bool = True, use_dct: bool = True, use_hadamard: bool = True, dropout: float = 0.0)
Bases: AttentionLayer
Mixed spectral attention using multiple transform types.
Combines different spectral transforms across heads for diverse frequency representations.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
hidden_dim
|
int
|
Hidden dimension. |
required |
num_heads
|
int
|
Number of attention heads (should be divisible by 3 for even split). |
8
|
use_fft
|
bool
|
Whether to include FFT heads. |
True
|
use_dct
|
bool
|
Whether to include DCT heads. |
True
|
use_hadamard
|
bool
|
Whether to include Hadamard heads. |
True
|
dropout
|
float
|
Dropout probability. |
0.0
|
Methods:
| Name | Description |
|---|---|
forward |
Forward pass of mixed spectral attention. |
Source code in spectrans/layers/attention/lst.py
Functions
forward
forward(x: Tensor, _mask: Tensor | None = None, return_attention: bool = False) -> Tensor | tuple[Tensor, ...]
Forward pass of mixed spectral attention.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Input of shape (batch_size, seq_len, hidden_dim). | required |
| _mask | Tensor \| None | Attention mask (not implemented for spectral attention). | None |
| return_attention | bool | Whether to return attention weights. | False |
Returns:
| Type | Description |
|---|---|
| Tensor or tuple[Tensor, Tensor] | Output and optionally None for weights. |
Source code in spectrans/layers/attention/lst.py
PerformerAttention
PerformerAttention(hidden_dim: int, num_heads: int = 8, num_features: int | None = None, generalized: bool = False, dropout: float = 0.0)
Bases: SpectralAttention
Performer-style attention with FAVOR+ algorithm.
Implements the Performer architecture with positive orthogonal random features (FAVOR+) for softmax kernel approximation.
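The positive random features behind FAVOR+ can be sketched in NumPy. This illustrates the estimator, not the PerformerAttention code: `positive_random_features` is a hypothetical helper, and for brevity it draws i.i.d. Gaussian projections and skips the orthogonalization step.

```python
import numpy as np


def positive_random_features(x: np.ndarray, w: np.ndarray) -> np.ndarray:
    """Positive random features phi(x)_i = exp(w_i . x - |x|^2 / 2) / sqrt(m).

    x: (n, d) inputs; w: (m, d) rows drawn from N(0, I).
    In expectation, phi(x) . phi(y) = exp(x . y), the unnormalized softmax kernel.
    """
    m = w.shape[0]
    sq_norm = 0.5 * np.sum(x**2, axis=-1, keepdims=True)
    return np.exp(x @ w.T - sq_norm) / np.sqrt(m)


rng = np.random.default_rng(0)
d, m = 4, 20000
x = 0.3 * rng.standard_normal((1, d))
y = 0.3 * rng.standard_normal((1, d))
w = rng.standard_normal((m, d))  # i.i.d. rows; FAVOR+ would also orthogonalize them

approx = (positive_random_features(x, w) @ positive_random_features(y, w).T).item()
exact = np.exp(x @ y.T).item()
print(approx, exact)  # close; the Monte Carlo error shrinks roughly like 1/sqrt(m)
```

Because every feature is positive, the implied attention weights stay positive, which is the property that distinguishes these features from the trigonometric ones.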
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| hidden_dim | int | Hidden dimension. | required |
| num_heads | int | Number of attention heads. | 8 |
| num_features | int \| None | Number of random features. | None |
| generalized | bool | Whether to use generalized attention (without softmax). | False |
| dropout | float | Dropout probability. | 0.0 |
Attributes:
| Name | Type | Description |
|---|---|---|
| generalized | bool | Whether using generalized attention. |
Methods:
| Name | Description |
|---|---|
| forward | Forward pass of Performer attention. |
Source code in spectrans/layers/attention/spectral.py
Functions
forward
forward(x: Tensor, mask: Tensor | None = None, return_attention: bool = False) -> Tensor | tuple[Tensor, ...]
Forward pass of Performer attention.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Input of shape (batch_size, seq_len, hidden_dim). | required |
| mask | Tensor \| None | Attention mask. | None |
| return_attention | bool | Whether to return attention weights. | False |
Returns:
| Type | Description |
|---|---|
| Tensor or tuple[Tensor, Tensor] | Output tensor and optionally None for weights. |
Source code in spectrans/layers/attention/spectral.py
SpectralAttention
SpectralAttention(hidden_dim: int, num_heads: int = 8, num_features: int | None = None, head_dim: int | None = None, kernel_type: Literal['gaussian', 'softmax'] = 'softmax', use_orthogonal: bool = True, feature_redraw: bool = False, dropout: float = 0.0, use_bias: bool = True)
Bases: AttentionLayer
Multi-head spectral attention using RFF approximation.
Implements attention using Random Fourier Features to approximate the softmax kernel with linear complexity.
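A minimal NumPy sketch of the random Fourier feature idea, using the classic trigonometric features as an illustration rather than spectrans' RandomFeatureMap (`random_fourier_features` is hypothetical): the features approximate the Gaussian kernel, and rescaling by the input norms turns that estimate into the softmax kernel exp(q . k).

```python
import numpy as np


def random_fourier_features(x: np.ndarray, w: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Trigonometric random features z(x) = sqrt(2/m) cos(W x + b).

    With W ~ N(0, I) and b ~ U[0, 2*pi), E[z(x) . z(y)] = exp(-|x - y|^2 / 2).
    """
    m = w.shape[0]
    return np.sqrt(2.0 / m) * np.cos(x @ w.T + b)


rng = np.random.default_rng(0)
d, m = 4, 20000
x = 0.3 * rng.standard_normal((1, d))
y = 0.3 * rng.standard_normal((1, d))
w = rng.standard_normal((m, d))            # unit-bandwidth Gaussian kernel
b = rng.uniform(0.0, 2.0 * np.pi, size=m)  # random phases

gauss_approx = (random_fourier_features(x, w, b) @ random_fourier_features(y, w, b).T).item()
gauss_exact = np.exp(-0.5 * np.sum((x - y) ** 2))

# Softmax kernel via the identity exp(x.y) = exp(|x|^2/2) * exp(|y|^2/2) * exp(-|x-y|^2/2)
norm_factor = np.exp(0.5 * np.sum(x**2) + 0.5 * np.sum(y**2))
softmax_approx = norm_factor * gauss_approx
softmax_exact = np.exp((x @ y.T).item())
print(abs(gauss_approx - gauss_exact), abs(softmax_approx - softmax_exact))
```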
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| hidden_dim | int | Hidden dimension of the model. | required |
| num_heads | int | Number of attention heads. | 8 |
| num_features | int \| None | Number of random features. If None, uses hidden_dim. | None |
| head_dim | int \| None | Dimension per head. If None, uses hidden_dim // num_heads. | None |
| kernel_type | Literal['gaussian', 'softmax'] | Type of kernel to approximate. | "softmax" |
| use_orthogonal | bool | Whether to use orthogonal random features. | True |
| feature_redraw | bool | Whether to redraw features at each forward pass. | False |
| dropout | float | Dropout probability. | 0.0 |
| use_bias | bool | Whether to use bias in projections. | True |
Attributes:
| Name | Type | Description |
|---|---|---|
| head_dim | int | Dimension per attention head. |
| num_features | int | Number of random features used. |
| q_proj | Linear | Query projection. |
| k_proj | Linear | Key projection. |
| v_proj | Linear | Value projection. |
| out_proj | Linear | Output projection. |
| kernel | RandomFeatureMap \| KernelFunction | Kernel for attention approximation. |
Methods:
| Name | Description |
|---|---|
| forward | Forward pass of spectral attention. |
Source code in spectrans/layers/attention/spectral.py
Functions
forward
forward(x: Tensor, mask: Tensor | None = None, return_attention: bool = False) -> Tensor | tuple[Tensor, ...]
Forward pass of spectral attention.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Input tensor of shape (batch_size, seq_len, hidden_dim). | required |
| mask | Tensor \| None | Attention mask of shape (batch_size, seq_len). | None |
| return_attention | bool | Whether to return attention weights (not supported). | False |
Returns:
| Type | Description |
|---|---|
| Tensor or tuple[Tensor, Tensor] | Output tensor of shape (batch_size, seq_len, hidden_dim). If return_attention=True, also returns None (weights not available). |
Source code in spectrans/layers/attention/spectral.py
AdaptiveGlobalFilter
AdaptiveGlobalFilter(hidden_dim: int, sequence_length: int, activation: ActivationType = 'sigmoid', dropout: float = 0.0, norm_eps: float = 1e-05, learnable_filters: bool = True, fft_norm: FFTNorm = 'ortho', filter_init_std: float = 0.02, filter_regularization: float = 0.0, adaptive_initialization: bool = True, spectral_dropout_p: float = 0.0)
Bases: FilterMixingLayer
Adaptive Global Filter with regularization and frequency-aware initialization.
Extends global filtering with adaptive initialization, optional L2 and spectral-dropout regularization, and improved training stability.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| hidden_dim | int | Hidden dimension of input tensors. | required |
| sequence_length | int | Expected sequence length. | required |
| activation | ActivationType | Filter activation function. | "sigmoid" |
| dropout | float | Dropout probability. | 0.0 |
| norm_eps | float | Numerical stability epsilon. | 1e-5 |
| learnable_filters | bool | Whether filters are learnable. | True |
| fft_norm | str | FFT normalization. | "ortho" |
| filter_init_std | float | Filter initialization standard deviation. | 0.02 |
| filter_regularization | float | L2 regularization strength for filter parameters. | 0.0 |
| adaptive_initialization | bool | Whether to use frequency-aware initialization. | True |
| spectral_dropout_p | float | Spectral dropout probability in frequency domain. | 0.0 |
Attributes:
| Name | Type | Description |
|---|---|---|
| filter_regularization | float | Regularization strength. |
| adaptive_initialization | bool | Whether adaptive initialization is used. |
| spectral_dropout_p | float | Spectral dropout probability. |
| spectral_dropout | Module | Spectral dropout layer. |
Methods:
| Name | Description |
|---|---|
| forward | Apply adaptive global filtering. |
| get_filter_response | Get adaptive frequency response. |
| get_regularization_loss | Compute L2 regularization loss for filter parameters. |
| get_spectral_properties | Get adaptive filter properties. |
Source code in spectrans/layers/mixing/global_filter.py
Functions
forward
Apply adaptive global filtering.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Input tensor of shape (batch_size, sequence_length, hidden_dim). | required |
Returns:
| Type | Description |
|---|---|
| Tensor | Adaptively filtered tensor. |
Source code in spectrans/layers/mixing/global_filter.py
get_filter_response
Get adaptive frequency response.
Returns:
| Type | Description |
|---|---|
| Tensor | Complex frequency response with current parameters. |
Source code in spectrans/layers/mixing/global_filter.py
get_regularization_loss
Compute L2 regularization loss for filter parameters.
Returns:
| Type | Description |
|---|---|
| Tensor | Scalar regularization loss. |
Source code in spectrans/layers/mixing/global_filter.py
get_spectral_properties
Get adaptive filter properties.
Returns:
| Type | Description |
|---|---|
| dict[str, str \| bool \| int] | Comprehensive properties including adaptive features. |
Source code in spectrans/layers/mixing/global_filter.py
AFNOMixing
AFNOMixing(hidden_dim: int, max_sequence_length: int, modes_seq: int | None = None, modes_hidden: int | None = None, mlp_ratio: float = 2.0, activation: ActivationType = 'gelu', dropout: float = 0.0)
Bases: MixingLayer
Adaptive Fourier Neural Operator mixing layer.
This layer mixes tokens by applying learnable transformations to a truncated set of Fourier modes; keeping only modes_seq and modes_hidden modes reduces computational cost while aiming to retain model expressiveness.
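The cost saving comes from mode truncation. A minimal NumPy sketch of just that step, assuming an rfft along the sequence axis; the learnable weights and MLP of the real layer are omitted, and `truncate_modes` is hypothetical.

```python
import numpy as np


def truncate_modes(x: np.ndarray, modes_seq: int) -> np.ndarray:
    """Keep only the lowest `modes_seq` Fourier modes along the sequence axis.

    x: real array of shape (batch, seq_len, hidden_dim).
    A learnable operation would act on the retained modes; here they pass
    through unchanged so the effect of truncation is visible.
    """
    seq_len = x.shape[1]
    x_hat = np.fft.rfft(x, axis=1, norm="ortho")
    x_hat[:, modes_seq:] = 0.0  # drop high-frequency modes
    return np.fft.irfft(x_hat, n=seq_len, axis=1, norm="ortho")


t = np.arange(64) / 64
low = np.sin(2 * np.pi * 3 * t)    # frequency 3: kept
high = np.sin(2 * np.pi * 20 * t)  # frequency 20: dropped
x = (low + high)[None, :, None].repeat(2, axis=2)  # shape (1, 64, 2)

y = truncate_modes(x, modes_seq=16)
print(y.shape)                       # (1, 64, 2)
print(np.allclose(y[0, :, 0], low))  # True
```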
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| hidden_dim | int | Hidden dimension of the input/output tensors. | required |
| max_sequence_length | int | Maximum sequence length the model will process. | required |
| modes_seq | int \| None | Number of Fourier modes to keep in sequence dimension. If None, defaults to max_sequence_length // 2. | None |
| modes_hidden | int \| None | Number of Fourier modes to keep in hidden dimension. If None, defaults to hidden_dim // 2. | None |
| mlp_ratio | float | Expansion ratio for the MLP in Fourier domain. Default is 2.0. | 2.0 |
| activation | str | Activation function for MLP. Default is 'gelu'. | 'gelu' |
| dropout | float | Dropout probability for MLP. Default is 0.0. | 0.0 |
Attributes:
| Name | Type | Description |
|---|---|---|
| hidden_dim | int | Hidden dimension size. |
| max_sequence_length | int | Maximum supported sequence length. |
| modes_seq | int | Number of retained Fourier modes in sequence dimension. |
| modes_hidden | int | Number of retained Fourier modes in hidden dimension. |
| mlp_ratio | float | MLP expansion ratio. |
| fourier_weight | Parameter | Complex-valued learnable weights for Fourier modes. |
| mlp | Sequential | MLP applied in Fourier domain. |
Examples:
>>> import torch
>>> from spectrans.layers import AFNOMixing
>>> layer = AFNOMixing(hidden_dim=768, max_sequence_length=512, modes_seq=128)
>>> x = torch.randn(32, 512, 768)
>>> output = layer(x)
>>> print(output.shape)
torch.Size([32, 512, 768])
Methods:
| Name | Description |
|---|---|
| forward | Apply AFNO mixing to input tensor. |
| get_spectral_properties | Get mathematical properties of AFNO operation. |
| from_config | Create AFNOMixing layer from configuration. |
Source code in spectrans/layers/mixing/afno.py
Functions
forward
Apply AFNO mixing to input tensor.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Input tensor of shape (batch_size, sequence_length, hidden_dim). | required |
Returns:
| Type | Description |
|---|---|
| Tensor | Output tensor of same shape as input. |
Source code in spectrans/layers/mixing/afno.py
get_spectral_properties
Get mathematical properties of AFNO operation.
Returns:
| Type | Description |
|---|---|
| dict[str, bool] | Mathematical properties of the transform. |
Source code in spectrans/layers/mixing/afno.py
from_config classmethod
from_config(config: AFNOMixingConfig) -> AFNOMixing
Create AFNOMixing layer from configuration.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | AFNOMixingConfig | Configuration object with layer parameters. | required |
Returns:
| Type | Description |
|---|---|
| AFNOMixing | Configured AFNO mixing layer. |
Source code in spectrans/layers/mixing/afno.py
FilterMixingLayer
FilterMixingLayer(hidden_dim: int, sequence_length: int, dropout: float = 0.0, norm_eps: float = 1e-05, learnable_filters: bool = True)
Bases: MixingLayer
Base class for frequency-domain filtering operations.
Filter mixing layers apply learnable filters in the frequency domain, enabling selective emphasis or suppression of frequency components for improved sequence modeling capabilities.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| hidden_dim | int | Hidden dimension of the model. | required |
| sequence_length | int | Expected sequence length for filter initialization. | required |
| dropout | float | Dropout probability for regularization. | 0.0 |
| norm_eps | float | Epsilon for numerical stability. | 1e-5 |
| learnable_filters | bool | Whether filters are learnable parameters. | True |
Attributes:
| Name | Type | Description |
|---|---|---|
| sequence_length | int | Expected sequence length. |
| learnable_filters | bool | Whether filters are learnable. |
Methods:
| Name | Description |
|---|---|
| get_spectral_properties | Get properties specific to filtering operations. |
| get_filter_response | Get the frequency response of the current filters. |
| analyze_frequency_response | Analyze the frequency response characteristics. |
Source code in spectrans/layers/mixing/base.py
Functions
get_spectral_properties
Get properties specific to filtering operations.
Returns:
| Type | Description |
|---|---|
| dict[str, Any] | Dictionary containing filter-specific properties. |
Source code in spectrans/layers/mixing/base.py
get_filter_response abstractmethod
Get the frequency response of the current filters.
Returns:
| Type | Description |
|---|---|
| Tensor | Complex-valued frequency response of shape matching the filter parameters. |
Source code in spectrans/layers/mixing/base.py
analyze_frequency_response
Analyze the frequency response characteristics.
Returns:
| Type | Description |
|---|---|
| dict[str, Tensor] | Dictionary containing 'magnitude' (magnitude response), 'phase' (phase response), 'group_delay' (group delay response), and 'passband_energy' (energy in different frequency bands). |
Source code in spectrans/layers/mixing/base.py
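For intuition about the magnitude, phase, and group-delay entries returned by analyze_frequency_response, a minimal NumPy sketch computing them from a sampled complex response. `analyze_response` is a hypothetical helper, not the method itself, and it assumes the response is sampled at the rfft frequencies of a length-n signal; the 'passband_energy' entry is omitted because its band boundaries are not documented here.

```python
import numpy as np


def analyze_response(h: np.ndarray) -> dict[str, np.ndarray]:
    """Magnitude, unwrapped phase and group delay of a complex frequency response."""
    n = 2 * (len(h) - 1)                      # length of the underlying signal
    omega = 2 * np.pi * np.arange(len(h)) / n  # angular frequency of each bin
    phase = np.unwrap(np.angle(h))
    return {
        "magnitude": np.abs(h),
        "phase": phase,
        "group_delay": -np.gradient(phase, omega),  # tau(omega) = -d(phase)/d(omega)
    }


# A pure delay of 3 samples, H(omega) = exp(-3j * omega), has group delay 3 everywhere
n = 64
omega = 2 * np.pi * np.arange(n // 2 + 1) / n
props = analyze_response(np.exp(-3j * omega))
print(np.allclose(props["magnitude"], 1.0), np.allclose(props["group_delay"], 3.0))  # True True
```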
FourierMixing
FourierMixing(hidden_dim: int, dropout: float = 0.0, norm_eps: float = 1e-05, energy_tolerance: float = 0.0001, fft_norm: FFTNorm = 'ortho', keep_complex: bool = False)
Bases: UnitaryMixingLayer
FNet-style Fourier mixing layer.
Implements the core FNet mixing operation using 2D Fourier transforms along both sequence and feature dimensions, providing an alternative to attention with \(O(n \log n)\) complexity.
The operation performs:
1. A 2D FFT across the sequence and feature dimensions.
2. Either real-part extraction for the final output (original FNet behavior) or retention of the complex values to preserve full information.
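The operation described above maps directly onto NumPy's FFT routines. A minimal sketch (`fnet_mixing` is hypothetical; it assumes the "ortho" normalization documented as the default for fft_norm) that also shows why keep_complex matters for energy preservation:

```python
import numpy as np


def fnet_mixing(x: np.ndarray, keep_complex: bool = False) -> np.ndarray:
    """FNet-style mixing: 2D FFT over the (sequence, hidden) axes, optionally real part only.

    x: array of shape (batch, seq_len, hidden_dim).
    """
    y = np.fft.fft2(x, axes=(-2, -1), norm="ortho")
    return y if keep_complex else y.real


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 16, 8))

y_complex = fnet_mixing(x, keep_complex=True)
y_real = fnet_mixing(x)
print(y_real.shape, y_real.dtype)  # (2, 16, 8) float64
# With norm="ortho" the complex output preserves energy (Parseval); the real part alone in general does not
print(np.isclose(np.sum(np.abs(y_complex) ** 2), np.sum(x**2)))  # True
```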
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| hidden_dim | int | Hidden dimension of the input tensors. | required |
| dropout | float | Dropout probability applied after the mixing operation. | 0.0 |
| norm_eps | float | Epsilon for numerical stability. | 1e-5 |
| energy_tolerance | float | Tolerance for energy preservation verification. | 1e-4 |
| fft_norm | str | Normalization mode for FFT operations ("forward", "backward", "ortho"). | "ortho" |
| keep_complex | bool | If True, keeps complex values from FFT. If False (default), takes only the real part as in original FNet. | False |
Attributes:
| Name | Type | Description |
|---|---|---|
| fft2d | FFT2D | 2D Fourier transform module. |
| keep_complex | bool | Whether to keep complex values or extract real part. |
Methods:
| Name | Description |
|---|---|
| forward | Apply Fourier mixing to input tensor. |
| get_spectral_properties | Get spectral properties of Fourier mixing. |
| from_config | Create FourierMixing layer from configuration. |
Source code in spectrans/layers/mixing/fourier.py
Functions
forward
Apply Fourier mixing to input tensor.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Input tensor of shape (batch_size, sequence_length, hidden_dim). | required |
Returns:
| Type | Description |
|---|---|
| Tensor | Mixed tensor of same shape. Complex if keep_complex=True, real values only if keep_complex=False (default). |
Source code in spectrans/layers/mixing/fourier.py
get_spectral_properties
Get spectral properties of Fourier mixing.
Returns:
| Type | Description |
|---|---|
| dict[str, str \| bool] | Properties including energy preservation and domain information. |
Source code in spectrans/layers/mixing/fourier.py
from_config classmethod
from_config(config: FourierMixingConfig) -> FourierMixing
Create FourierMixing layer from configuration.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | FourierMixingConfig | Configuration object with layer parameters. | required |
Returns:
| Type | Description |
|---|---|
| FourierMixing | Configured Fourier mixing layer. |
Source code in spectrans/layers/mixing/fourier.py
FourierMixing1D
FourierMixing1D(hidden_dim: int, dropout: float = 0.0, norm_eps: float = 1e-05, energy_tolerance: float = 0.0001, fft_norm: FFTNorm = 'ortho', keep_complex: bool = False)
Bases: UnitaryMixingLayer
1D Fourier mixing along sequence dimension only.
Applies Fourier transform only along the sequence dimension, preserving feature dimension locality while mixing tokens.
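A minimal NumPy sketch (the helper is hypothetical, with ortho normalization and real-part output assumed to mirror the documented defaults) showing that a 1D FFT along the sequence axis mixes tokens within each feature channel but never across channels:

```python
import numpy as np


def fourier_mixing_1d(x: np.ndarray) -> np.ndarray:
    """FFT along the sequence axis only (axis 1), keeping the real part."""
    return np.fft.fft(x, axis=1, norm="ortho").real


rng = np.random.default_rng(0)
x = rng.standard_normal((1, 16, 8))
x_perturbed = x.copy()
x_perturbed[0, 5, 3] += 1.0  # change one token in feature channel 3

delta = fourier_mixing_1d(x_perturbed) - fourier_mixing_1d(x)
changed = np.abs(delta).max(axis=1)[0] > 1e-12  # which feature channels moved
print(changed.nonzero()[0])  # [3]: tokens mix, feature channels do not
```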
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| hidden_dim | int | Hidden dimension of the input tensors. | required |
| dropout | float | Dropout probability applied after the mixing operation. | 0.0 |
| norm_eps | float | Epsilon for numerical stability. | 1e-5 |
| energy_tolerance | float | Tolerance for energy preservation verification. | 1e-4 |
| fft_norm | str | Normalization mode for FFT operations. | "ortho" |
| keep_complex | bool | If True, keeps complex values from FFT. If False (default), takes only the real part. | False |
Attributes:
| Name | Type | Description |
|---|---|---|
| fft1d | FFT1D | 1D Fourier transform module. |
| keep_complex | bool | Whether to keep complex values or extract real part. |
Methods:
| Name | Description |
|---|---|
| forward | Apply 1D Fourier mixing to input tensor. |
| get_spectral_properties | Get spectral properties of 1D Fourier mixing. |
Source code in spectrans/layers/mixing/fourier.py
Functions
forward
Apply 1D Fourier mixing to input tensor.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Input tensor of shape (batch_size, sequence_length, hidden_dim). | required |
Returns:
| Type | Description |
|---|---|
| Tensor | Mixed tensor with Fourier transform applied along sequence dimension. Complex if keep_complex=True, real values only if keep_complex=False. |
Source code in spectrans/layers/mixing/fourier.py
get_spectral_properties
Get spectral properties of 1D Fourier mixing.
Returns:
| Type | Description |
|---|---|
| dict[str, str \| bool] | Properties specific to 1D sequence mixing. |
Source code in spectrans/layers/mixing/fourier.py
GlobalFilterMixing
GlobalFilterMixing(hidden_dim: int, sequence_length: int, activation: ActivationType = 'sigmoid', dropout: float = 0.0, norm_eps: float = 1e-05, learnable_filters: bool = True, fft_norm: FFTNorm = 'ortho', filter_init_std: float = 0.02)
Bases: FilterMixingLayer
Global Filter Network mixing layer.
Implements the core GFNet mixing operation with learnable complex filters applied in the frequency domain along the sequence dimension.
The learned filters are interpolated to the actual input sequence length, so the layer accepts variable-length inputs while preserving the learned frequency patterns. A fixed-size filter, by contrast, only applies at the length it was built for.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| hidden_dim | int | Hidden dimension of input tensors. | required |
| sequence_length | int | Base sequence length for filter parameter initialization. The filters are interpolated to match the actual input sequence length. | required |
| activation | ActivationType | Activation function applied to filter parameters ("sigmoid", "tanh", "identity"). | "sigmoid" |
| dropout | float | Dropout probability applied after filtering. | 0.0 |
| norm_eps | float | Epsilon for numerical stability. | 1e-5 |
| learnable_filters | bool | Whether filter parameters are learnable (always True for this class). | True |
| fft_norm | str | FFT normalization mode. | "ortho" |
| filter_init_std | float | Standard deviation for filter parameter initialization. | 0.02 |
Attributes:
| Name | Type | Description |
|---|---|---|
| activation | str | Activation function name. |
| filter_real | Parameter | Real part of complex filter parameters. |
| filter_imag | Parameter | Imaginary part of complex filter parameters. |
| fft1d | FFT1D | 1D FFT transform for the sequence dimension. |
| activation_fn | Module | Activation function module (Sigmoid, Tanh, or Identity). |
Methods:
| Name | Description |
|---|---|
| forward | Apply global filtering to input tensor. |
| get_filter_response | Get the current frequency response of the filters. |
| get_spectral_properties | Get spectral properties of global filtering. |
| from_config | Create GlobalFilterMixing layer from configuration. |
Source code in spectrans/layers/mixing/global_filter.py
Functions
forward
Apply global filtering to input tensor.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Input tensor of shape (batch_size, sequence_length, hidden_dim). | required |
Returns:
| Type | Description |
|---|---|
| Tensor | Filtered tensor of the same shape as the input. |
Source code in spectrans/layers/mixing/global_filter.py
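The sketch below illustrates global filtering with length adaptation. `GlobalFilterSketch` is hypothetical: it assumes full-length complex filters of shape `(sequence_length, hidden_dim)` (matching the documented `get_filter_response` shape), applies a sigmoid to the real and imaginary parameters separately, and uses linear interpolation when the input length differs from the base length. The actual spectrans layer may use a real FFT, a different interpolation scheme, or apply the activation differently.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def resize_filter(param: torch.Tensor, length: int) -> torch.Tensor:
    """Linearly interpolate a (L0, D) filter to (length, D) along the frequency axis."""
    resized = F.interpolate(param.T.unsqueeze(0), size=length, mode="linear", align_corners=True)
    return resized.squeeze(0).T


class GlobalFilterSketch(nn.Module):
    """Hypothetical GFNet-style layer, not the spectrans implementation."""

    def __init__(self, hidden_dim: int, sequence_length: int, filter_init_std: float = 0.02):
        super().__init__()
        shape = (sequence_length, hidden_dim)
        self.filter_real = nn.Parameter(torch.randn(shape) * filter_init_std)
        self.filter_imag = nn.Parameter(torch.randn(shape) * filter_init_std)

    def filter_response(self, length: int) -> torch.Tensor:
        # Assumption: sigmoid is applied to the raw real and imaginary parameters.
        real, imag = torch.sigmoid(self.filter_real), torch.sigmoid(self.filter_imag)
        if length != real.shape[0]:
            real, imag = resize_filter(real, length), resize_filter(imag, length)
        return torch.complex(real, imag)  # (length, hidden_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        n = x.shape[1]
        spectrum = torch.fft.fft(x, dim=1, norm="ortho")   # (B, N, D), complex
        filtered = spectrum * self.filter_response(n)       # broadcast over the batch
        # The filter is not Hermitian-symmetric, so keep only the real part.
        return torch.fft.ifft(filtered, dim=1, norm="ortho").real


layer = GlobalFilterSketch(hidden_dim=8, sequence_length=16)
print(layer(torch.randn(2, 16, 8)).shape)  # torch.Size([2, 16, 8])
print(layer(torch.randn(2, 24, 8)).shape)  # torch.Size([2, 24, 8]), filters interpolated
```

Interpolating in the frequency domain keeps the overall shape of the learned response when the sequence length changes, which is the resolution independence described above.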
get_filter_response
Get the current frequency response of the filters.
Returns:
| Type | Description |
|---|---|
| Tensor | Complex-valued frequency response of shape (sequence_length, hidden_dim). |
Source code in spectrans/layers/mixing/global_filter.py
get_spectral_properties
Get spectral properties of global filtering.
Returns:
| Type | Description |
|---|---|
| dict[str, str \| bool \| int] | Properties including filter characteristics. |
Source code in spectrans/layers/mixing/global_filter.py
from_config (classmethod)
from_config(config: GlobalFilterMixingConfig) -> GlobalFilterMixing
Create GlobalFilterMixing layer from configuration.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | GlobalFilterMixingConfig | Configuration object with layer parameters. | required |
Returns:
| Type | Description |
|---|---|
| GlobalFilterMixing | Configured global filter mixing layer. |
Source code in spectrans/layers/mixing/global_filter.py
GlobalFilterMixing2D
GlobalFilterMixing2D(hidden_dim: int, sequence_length: int, activation: ActivationType = 'sigmoid', dropout: float = 0.0, norm_eps: float = 1e-05, learnable_filters: bool = True, fft_norm: FFTNorm = 'ortho', filter_init_std: float = 0.02)
Bases: FilterMixingLayer
2D Global Filter mixing with filtering along both dimensions.
Extends global filtering to both sequence and feature dimensions, similar to FNet's 2D FFT but with learnable filters.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| hidden_dim | int | Hidden dimension of input tensors. | required |
| sequence_length | int | Expected sequence length. | required |
| activation | ActivationType | Activation function for filter parameters. | "sigmoid" |
| dropout | float | Dropout probability. | 0.0 |
| norm_eps | float | Epsilon for numerical stability. | 1e-5 |
| learnable_filters | bool | Whether filters are learnable. | True |
| fft_norm | str | FFT normalization mode. | "ortho" |
| filter_init_std | float | Filter parameter initialization standard deviation. | 0.02 |
Attributes:
| Name | Type | Description |
|---|---|---|
| filter_real | Parameter | Real part of 2D complex filters. |
| filter_imag | Parameter | Imaginary part of 2D complex filters. |
| fft2d | FFT2D | 2D FFT transform module. |
| activation_fn | Module | Activation function. |
Methods:
| Name | Description |
|---|---|
| forward | Apply 2D global filtering. |
| get_filter_response | Get 2D frequency response. |
| get_spectral_properties | Get 2D filter properties. |
Source code in spectrans/layers/mixing/global_filter.py
Functions
forward
Apply 2D global filtering.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Input tensor of shape (batch_size, sequence_length, hidden_dim). | required |
Returns:
| Type | Description |
|---|---|
| Tensor | Filtered tensor of the same shape as the input. |
Source code in spectrans/layers/mixing/global_filter.py
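A minimal sketch of 2D frequency-domain filtering with `torch.fft.fft2` over the sequence and feature axes. `global_filter_2d` is a hypothetical helper that assumes one complex weight per (sequence, feature) frequency; it omits the learnable parameters, activation, and dropout of `GlobalFilterMixing2D`.

```python
import torch


def global_filter_2d(x: torch.Tensor, filt: torch.Tensor) -> torch.Tensor:
    """Filter x of shape (B, N, D) with one complex weight per 2D frequency (N, D)."""
    spectrum = torch.fft.fft2(x, dim=(1, 2), norm="ortho")  # FFT over sequence and features
    return torch.fft.ifft2(spectrum * filt, dim=(1, 2), norm="ortho").real


x = torch.randn(2, 16, 8)

# An all-ones filter passes every frequency, so the input comes back unchanged.
identity = torch.ones(16, 8, dtype=torch.complex64)
print(torch.allclose(global_filter_2d(x, identity), x, atol=1e-5))  # True

# Keeping only the (0, 0) bin returns each sample's global mean at every position.
dc_only = torch.zeros(16, 8, dtype=torch.complex64)
dc_only[0, 0] = 1.0
out = global_filter_2d(x, dc_only)
```

Compared with FNet-style 2D mixing, the only change is the element-wise multiplication by `filt`, which is where learnable weights enter.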
get_filter_response
Get 2D frequency response.
Returns:
| Type | Description |
|---|---|
| Tensor | Complex 2D frequency response. |
Source code in spectrans/layers/mixing/global_filter.py
get_spectral_properties
Get 2D filter properties.
Returns:
| Type | Description |
|---|---|
| dict[str, str \| bool \| int] | 2D filtering characteristics. |
Source code in spectrans/layers/mixing/global_filter.py
MixingLayer
Bases: SpectralComponent
Base class for spectral mixing operations.
Mixing layers perform token mixing operations using various spectral transforms instead of traditional attention mechanisms. This class provides spectral-specific functionality including mathematical property verification and standardized interfaces for spectral transform operations.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| hidden_dim | int | Hidden dimension of the model. | required |
| dropout | float | Dropout probability for regularization. | 0.0 |
| norm_eps | float | Epsilon for numerical stability in normalization. | 1e-5 |
Attributes:
| Name | Type | Description |
|---|---|---|
| hidden_dim | int | Hidden dimension of the model. |
| dropout | Module | Dropout layer for regularization. |
| norm_eps | float | Epsilon for numerical stability. |
Methods:
| Name | Description |
|---|---|
| get_spectral_properties | Get mathematical properties of the spectral operation. |
| verify_shape_consistency | Verify that input and output shapes are consistent. |
| compute_spectral_norm | Compute spectral norm for analysis and regularization. |
Source code in spectrans/layers/mixing/base.py
Functions
get_spectral_properties (abstractmethod)
Get mathematical properties of the spectral operation.
Returns:
| Type | Description |
|---|---|
| dict[str, Any] | Dictionary of mathematical properties, such as 'unitary' (bool, whether the transform is unitary), 'real_output' (bool, whether the output is guaranteed real), 'frequency_domain' (bool, whether the operation occurs in the frequency domain), and 'energy_preserving' (bool, whether energy is preserved). |
Source code in spectrans/layers/mixing/base.py
verify_shape_consistency
Verify that input and output shapes are consistent.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| input_tensor | Tensor | Input tensor to the mixing layer. | required |
| output_tensor | Tensor | Output tensor from the mixing layer. | required |
Returns:
| Type | Description |
|---|---|
| bool | True if shapes are consistent, False otherwise. |
Source code in spectrans/layers/mixing/base.py
compute_spectral_norm
Compute spectral norm for analysis and regularization.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| tensor | Tensor | Input tensor to compute the spectral norm of. | required |
Returns:
| Type | Description |
|---|---|
| Tensor | Spectral norm of the input tensor. |
Source code in spectrans/layers/mixing/base.py
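Hypothetical stand-ins for the two concrete helpers of `MixingLayer`. They assume "shape consistency" means exact shape equality and "spectral norm" means the largest singular value of each matrix over the last two dimensions, computed here with `torch.linalg.matrix_norm(..., ord=2)`; the library's own definitions may differ.

```python
import torch


def verify_shape_consistency(input_tensor: torch.Tensor, output_tensor: torch.Tensor) -> bool:
    # Token mixing must not change (batch, seq_len, hidden_dim).
    return input_tensor.shape == output_tensor.shape


def compute_spectral_norm(tensor: torch.Tensor) -> torch.Tensor:
    # Largest singular value of each matrix over the last two dims,
    # i.e. one value per batch element for a (B, N, D) tensor.
    return torch.linalg.matrix_norm(tensor, ord=2)


x = torch.randn(4, 32, 16)
print(verify_shape_consistency(x, torch.fft.fft(x, dim=1).real))  # True
print(compute_spectral_norm(x).shape)  # torch.Size([4])

m = torch.diag(torch.tensor([3.0, 1.0]))
print(compute_spectral_norm(m).item())  # approximately 3.0
```

For a diagonal matrix the spectral norm is simply the largest absolute diagonal entry, which makes the second check easy to verify by hand.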
RealFourierMixing
RealFourierMixing(hidden_dim: int, use_real_fft: bool = True, dropout: float = 0.0, norm_eps: float = 1e-05, energy_tolerance: float = 0.0001, fft_norm: FFTNorm = 'ortho')
Bases: UnitaryMixingLayer
Memory-efficient real Fourier mixing.
Uses real FFT operations to exploit Hermitian symmetry, providing ~2x memory and computational savings for real inputs.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| hidden_dim | int | Hidden dimension of the input tensors. | required |
| use_real_fft | bool | Whether to use real FFT for efficiency. | True |
| dropout | float | Dropout probability applied after mixing. | 0.0 |
| norm_eps | float | Epsilon for numerical stability. | 1e-5 |
| energy_tolerance | float | Tolerance for energy preservation verification. | 1e-4 |
| fft_norm | str | Normalization mode for FFT operations. | "ortho" |
Attributes:
| Name | Type | Description |
|---|---|---|
| use_real_fft | bool | Whether real FFT is enabled. |
| rfft | RFFT | Real FFT transform for the sequence dimension. |
| rfft2d | RFFT2D | Real 2D FFT transform for both dimensions. |
Methods:
| Name | Description |
|---|---|
| forward | Apply real Fourier mixing to input tensor. |
| get_spectral_properties | Get spectral properties of real Fourier mixing. |
Source code in spectrans/layers/mixing/fourier.py
Functions
forward
Apply real Fourier mixing to input tensor.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Input tensor of shape (batch_size, sequence_length, hidden_dim). Should be real-valued for optimal efficiency. | required |
Returns:
| Type | Description |
|---|---|
| Tensor | Mixed tensor, guaranteed to be real-valued. |
Source code in spectrans/layers/mixing/fourier.py
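The saving comes from Hermitian symmetry: for real input, the upper half of the spectrum is the complex conjugate of the lower half, so `torch.fft.rfft` stores only `n // 2 + 1` bins. The check below uses plain `torch.fft` rather than the spectrans `RFFT` wrapper.

```python
import torch

n = 32
x = torch.randn(4, n, 16)
full = torch.fft.fft(x, dim=1, norm="ortho")
half = torch.fft.rfft(x, dim=1, norm="ortho")
print(full.shape[1], half.shape[1])  # 32 17

# rfft returns exactly the first n // 2 + 1 bins of the full FFT ...
print(torch.allclose(half, full[:, : n // 2 + 1], atol=1e-5))  # True

# ... because the rest are redundant for real input: X[n - k] == conj(X[k]).
mirrored = full[:, 1:].flip(1)  # X[n-1], ..., X[1]
print(torch.allclose(mirrored.real, full[:, 1:].real, atol=1e-5),
      torch.allclose(mirrored.imag, -full[:, 1:].imag, atol=1e-5))  # True True

# irfft needs the original length to invert an even-length signal.
recon = torch.fft.irfft(half, n=n, dim=1, norm="ortho")
print(torch.allclose(recon, x, atol=1e-5))  # True
```

Passing `n` to `irfft` matters: from `n // 2 + 1` bins alone, an even length and the next odd length are indistinguishable.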
get_spectral_properties
Get spectral properties of real Fourier mixing.
Returns:
| Type | Description |
|---|---|
| dict[str, str \| bool] | Properties including efficiency characteristics. |
Source code in spectrans/layers/mixing/fourier.py
SeparableFourierMixing
SeparableFourierMixing(hidden_dim: int, mix_features: bool = True, mix_sequence: bool = True, dropout: float = 0.0, norm_eps: float = 1e-05, energy_tolerance: float = 0.0001, fft_norm: FFTNorm = 'ortho')
Bases: UnitaryMixingLayer
Separable Fourier mixing with sequence and feature transforms.
Applies separate 1D Fourier transforms along the sequence and feature dimensions; each can be enabled independently via mix_sequence and mix_features. Transforming only one axis is cheaper than a full 2D FFT and gives axis-specific mixing.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| hidden_dim | int | Hidden dimension of the input tensors. | required |
| mix_features | bool | Whether to apply FFT along the feature dimension. | True |
| mix_sequence | bool | Whether to apply FFT along the sequence dimension. | True |
| dropout | float | Dropout probability. | 0.0 |
| norm_eps | float | Epsilon for numerical stability. | 1e-5 |
| energy_tolerance | float | Tolerance for energy preservation verification. | 1e-4 |
| fft_norm | str | FFT normalization mode. | "ortho" |
Attributes:
| Name | Type | Description |
|---|---|---|
| mix_features | bool | Whether feature mixing is enabled. |
| mix_sequence | bool | Whether sequence mixing is enabled. |
| fft1d | FFT1D | 1D FFT transform module. |
Methods:
| Name | Description |
|---|---|
| forward | Apply separable Fourier mixing. |
| get_spectral_properties | Get properties of separable mixing. |
Source code in spectrans/layers/mixing/fourier.py
Functions
forward
Apply separable Fourier mixing.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Input tensor of shape (batch_size, sequence_length, hidden_dim). | required |
Returns:
| Type | Description |
|---|---|
| Tensor | Mixed tensor after applying the selected transforms. |
Source code in spectrans/layers/mixing/fourier.py
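Because the 2D DFT is separable, applying 1D FFTs along the feature axis and then along the sequence axis (keeping complex values in between) reproduces a full 2D FFT. `separable_fourier` is a hypothetical helper; whether the spectrans layer keeps complex intermediates or extracts a real part at the end is an assumption here.

```python
import torch


def separable_fourier(x: torch.Tensor, mix_sequence: bool = True, mix_features: bool = True) -> torch.Tensor:
    """Apply independent 1D FFTs along the sequence (dim 1) and/or feature (dim 2) axes."""
    out = x.to(torch.complex64)
    if mix_features:
        out = torch.fft.fft(out, dim=2, norm="ortho")
    if mix_sequence:
        out = torch.fft.fft(out, dim=1, norm="ortho")
    return out


x = torch.randn(2, 16, 8)
both = separable_fourier(x)
print(torch.allclose(both, torch.fft.fft2(x, dim=(1, 2), norm="ortho"), atol=1e-5))  # True

seq_only = separable_fourier(x, mix_features=False)
print(torch.allclose(seq_only, torch.fft.fft(x, dim=1, norm="ortho"), atol=1e-5))  # True
```

With both flags on, the result matches `fft2`; the practical benefit of the separable form is turning one axis off.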
get_spectral_properties
Get properties of separable mixing.
Returns:
| Type | Description |
|---|---|
| dict[str, str \| bool] | Properties reflecting the separable nature. |
Source code in spectrans/layers/mixing/fourier.py
UnitaryMixingLayer
UnitaryMixingLayer(hidden_dim: int, dropout: float = 0.0, norm_eps: float = 1e-05, energy_tolerance: float = 0.0001)
Bases: MixingLayer
Base class for unitary mixing operations.
Unitary mixing layers preserve energy (the squared L2 norm) and inner products, so stacking them neither amplifies nor attenuates activations, which helps keep training stable.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| hidden_dim | int | Hidden dimension of the model. | required |
| dropout | float | Dropout probability for regularization. | 0.0 |
| norm_eps | float | Epsilon for numerical stability. | 1e-5 |
| energy_tolerance | float | Tolerance for energy preservation verification. | 1e-4 |
Attributes:
| Name | Type | Description |
|---|---|---|
| energy_tolerance | float | Tolerance for energy preservation checks. |
Methods:
| Name | Description |
|---|---|
| get_spectral_properties | Get properties specific to unitary transforms. |
| verify_energy_preservation | Verify energy preservation (Parseval's theorem). |
| verify_orthogonality | Verify orthogonality of the transform matrix. |
Source code in spectrans/layers/mixing/base.py
Functions
get_spectral_properties
Get properties specific to unitary transforms.
Returns:
| Type | Description |
|---|---|
| dict[str, Any] | Dictionary containing unitary transform properties. |
Source code in spectrans/layers/mixing/base.py
verify_energy_preservation
Verify energy preservation (Parseval's theorem).
Checks that \(||\mathbf{output}||^2 \approx ||\mathbf{input}||^2\) within tolerance.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| input_tensor | Tensor | Input tensor before transformation. | required |
| output_tensor | Tensor | Output tensor after transformation. | required |
Returns:
| Type | Description |
|---|---|
| bool | True if energy is preserved within tolerance. |
Source code in spectrans/layers/mixing/base.py
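A minimal sketch of the Parseval check, assuming `energy_tolerance` is a relative tolerance on the total squared magnitude; the spectrans implementation may use an absolute tolerance or compare per sample. It also shows why `fft_norm="ortho"` matters: only the orthonormal FFT preserves energy.

```python
import torch


def verify_energy_preservation(input_tensor: torch.Tensor, output_tensor: torch.Tensor,
                               energy_tolerance: float = 1e-4) -> bool:
    """Parseval check: ||output||^2 should match ||input||^2 up to a relative tolerance."""
    energy_in = input_tensor.abs().pow(2).sum()
    energy_out = output_tensor.abs().pow(2).sum()
    return bool(torch.abs(energy_out - energy_in) <= energy_tolerance * energy_in)


x = torch.randn(2, 64, 32)
print(verify_energy_preservation(x, torch.fft.fft(x, dim=1, norm="ortho")))     # True
# Without normalization the FFT scales energy by the transform length (64 here).
print(verify_energy_preservation(x, torch.fft.fft(x, dim=1, norm="backward")))  # False
```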
verify_orthogonality
Verify that the transform matrix is unitary (orthogonal in the real-valued case).
Checks that \(\mathbf{T} \mathbf{T}^H \approx \mathbf{I}\), where \(\mathbf{T}^H\) is the conjugate transpose of \(\mathbf{T}\).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| transform_matrix | Tensor | Transform matrix to verify. | required |
Returns:
| Type | Description |
|---|---|
| bool | True if the matrix is unitary within tolerance. |
Source code in spectrans/layers/mixing/base.py
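For the DFT, the condition \(\mathbf{T}\mathbf{T}^H \approx \mathbf{I}\) can be checked directly by building the orthonormal DFT matrix. `dft_matrix` and this `verify_orthogonality` are hypothetical helpers that assume an absolute tolerance; they are not the spectrans implementation.

```python
import math

import torch


def dft_matrix(n: int) -> torch.Tensor:
    """Orthonormal DFT matrix T[k, m] = exp(-2*pi*i*k*m / n) / sqrt(n)."""
    k = torch.arange(n, dtype=torch.float64)
    angles = -2.0 * math.pi * torch.outer(k, k) / n
    return torch.polar(torch.full_like(angles, 1.0 / math.sqrt(n)), angles)


def verify_orthogonality(transform_matrix: torch.Tensor, tol: float = 1e-4) -> bool:
    """Check T @ T^H against the identity, i.e. that T is unitary."""
    n = transform_matrix.shape[-1]
    identity = torch.eye(n, dtype=transform_matrix.dtype)
    product = transform_matrix @ transform_matrix.conj().transpose(-2, -1)
    return torch.allclose(product, identity, atol=tol)


t = dft_matrix(8)
print(verify_orthogonality(t))        # True
print(verify_orthogonality(2.0 * t))  # False

# The matrix reproduces torch.fft.fft with norm="ortho".
v = torch.randn(8, dtype=torch.complex128)
print(torch.allclose(t @ v, torch.fft.fft(v, norm="ortho")))  # True
```

Scaling a unitary matrix by 2 multiplies \(\mathbf{T}\mathbf{T}^H\) by 4, which is why the second check fails.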
WaveletMixing
WaveletMixing(hidden_dim: int, wavelet: WaveletType = 'db4', levels: int = 3, mixing_mode: str = 'pointwise', dropout: float = 0.0)
Bases: Module
Token mixing layer using discrete wavelet transform.
Performs mixing in wavelet domain for multi-resolution processing. Decomposes input using DWT, applies learnable mixing to coefficients, and reconstructs the output with residual connections.
Mathematical Formulation
Given input tensor \(\mathbf{X} \in \mathbb{R}^{B \times N \times D}\) where \(B\) is batch size, \(N\) is sequence length, and \(D\) is hidden dimension:
Step 1: Channel-wise Decomposition
For each channel \(d \in \{0, 1, \ldots, D-1\}\), extract the channel signal:
$$\mathbf{x}^{(d)} = \mathbf{X}[:, :, d] \in \mathbb{R}^{B \times N}$$
Apply \(J\)-level DWT decomposition:
$$\text{DWT}_J(\mathbf{x}^{(d)}) = \left\{\mathbf{c}_{A_J}^{(d)}, \left\{\mathbf{c}_{D_j}^{(d)}\right\}_{j=1}^{J}\right\}$$
where:
- \(\mathbf{c}_{A_J}^{(d)} \in \mathbb{R}^{B \times L_{A_J}}\) are approximation coefficients at level \(J\)
- \(\mathbf{c}_{D_j}^{(d)} \in \mathbb{R}^{B \times L_{D_j}}\) are detail coefficients at level \(j\)
- \(L_{A_J}\) and \(L_{D_j}\) are coefficient lengths after subsampling
Step 2: Learnable Mixing
Apply mixing transformations based on mode:
Pointwise Mixing (`mixing_mode='pointwise'`):
$$\tilde{\mathbf{c}}_{A_J}^{(d)} = \mathbf{c}_{A_J}^{(d)} \odot \mathbf{W}_{A}[:, :L_{A_J}, d], \qquad \tilde{\mathbf{c}}_{D_j}^{(d)} = \mathbf{c}_{D_j}^{(d)} \odot \mathbf{W}_{D_j}[:, :L_{D_j}, d]$$
where \(\mathbf{W}_{A}, \mathbf{W}_{D_j} \in \mathbb{R}^{1 \times \max(L) \times D}\) are learnable parameters, sliced to each coefficient length, and \(\odot\) denotes element-wise multiplication with broadcasting.
Channel Mixing (`mixing_mode='channel'`):
Where \(\mathbf{W}_{A}, \mathbf{W}_{D_j} \in \mathbb{R}^{1 \times D \times D}\) are initialized as identity matrices.
Level Mixing (`mixing_mode='level'`):
Cross-level attention is applied to all coefficients simultaneously:
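Step 3: Reconstruction
Reconstruct the signal using inverse DWT:
$$\hat{\mathbf{x}}^{(d)} = \text{IDWT}_J\left(\tilde{\mathbf{c}}_{A_J}^{(d)}, \left\{\tilde{\mathbf{c}}_{D_j}^{(d)}\right\}_{j=1}^{J}\right)$$
If the reconstructed length differs from \(N\), \(\hat{\mathbf{x}}^{(d)}\) is padded or truncated back to length \(N\).
Step 4: Residual Connection and Dropout
Combine all channels and apply the residual connection:
$$\hat{\mathbf{X}} = \left[\hat{\mathbf{x}}^{(0)}, \ldots, \hat{\mathbf{x}}^{(D-1)}\right], \qquad \mathbf{Y} = \mathbf{X} + \text{Dropout}(\hat{\mathbf{X}})$$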
Complexity Analysis
- Time Complexity: \(O(DNJ)\) per forward pass
  - \(O(N)\) for DWT/IDWT per level and channel (linear in signal length), giving \(O(DNJ)\) over \(D\) channels and \(J\) levels
  - \(O(DJ)\) for mixing operations across all levels and channels
  - Dominated by the DWT/IDWT operations
- Space Complexity: \(O(DN + P)\) where \(P\) is the parameter count
  - \(O(DN)\) for storing coefficient tensors
  - Parameter count depends on mixing mode:
    - Pointwise: \(P = O(JLD)\) where \(L\) is the max coefficient length (one \(1 \times L \times D\) weight per coefficient set)
    - Channel: \(P = O(JD^2)\)
    - Level: \(P = O(D^2)\) for attention parameters
Implementation Notes
- Uses PyTorch-native DWT implementation for gradient compatibility
- Dynamic weight slicing ensures proper alignment with variable-length coefficients
- Perfect reconstruction property maintained through careful length handling
- Each channel processed independently for computational efficiency
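A minimal sketch of pointwise wavelet mixing, using a one-level orthonormal Haar DWT written in plain PyTorch instead of the library's `DWT1D` (the layer itself defaults to 'db4' with three levels). `haar_dwt`, `haar_idwt`, and `wavelet_mix_pointwise` are hypothetical; dropout and length adjustment are omitted, so the sequence length must be even.

```python
import math

import torch


def haar_dwt(x: torch.Tensor):
    """One-level orthonormal Haar DWT along the last axis (length must be even)."""
    even, odd = x[..., 0::2], x[..., 1::2]
    return (even + odd) / math.sqrt(2), (even - odd) / math.sqrt(2)


def haar_idwt(approx: torch.Tensor, detail: torch.Tensor) -> torch.Tensor:
    """Inverse of haar_dwt: rebuild and re-interleave even and odd samples."""
    even = (approx + detail) / math.sqrt(2)
    odd = (approx - detail) / math.sqrt(2)
    return torch.stack((even, odd), dim=-1).flatten(-2)


def wavelet_mix_pointwise(x, w_approx, w_detail):
    """x: (B, N, D); w_approx, w_detail: (D, N // 2) element-wise coefficient weights."""
    xc = x.transpose(1, 2)                  # (B, D, N): the DWT runs over the sequence
    approx, detail = haar_dwt(xc)           # each (B, D, N // 2)
    x_hat = haar_idwt(approx * w_approx, detail * w_detail)
    return x + x_hat.transpose(1, 2)        # residual connection (dropout omitted)


x = torch.randn(2, 16, 4)
ones = torch.ones(4, 8)
# Unit weights leave the coefficients untouched, so x_hat == x and the output is 2 * x.
print(torch.allclose(wavelet_mix_pointwise(x, ones, ones), 2 * x, atol=1e-5))  # True
# Zeroing the detail weights keeps only the coarse (pairwise-average) component.
smooth = wavelet_mix_pointwise(x, ones, torch.zeros(4, 8))
```

With unit weights the wavelet branch reconstructs its input exactly, so under the residual formula \(\mathbf{Y} = \mathbf{X} + \text{Dropout}(\hat{\mathbf{X}})\) the output starts near \(2\mathbf{X}\) rather than \(\mathbf{X}\).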
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| hidden_dim | int | Hidden dimension size \(D\). | required |
| wavelet | str | Wavelet type (e.g., 'db1', 'db4', 'sym2'). Determines filter bank characteristics. | 'db4' |
| levels | int | Number of decomposition levels \(J\). Controls the resolution hierarchy. | 3 |
| mixing_mode | str | Mixing strategy: 'pointwise' (element-wise), 'channel' (\(D \times D\) channel mixing, identity-initialized), 'level' (attention across levels). | 'pointwise' |
| dropout | float | Dropout probability applied to the reconstructed signal before the residual connection. | 0.0 |
Attributes:
| Name | Type | Description |
|---|---|---|
| dwt | DWT1D | Wavelet transform module implementing PyTorch-native DWT/IDWT. |
| mixing_weights | ParameterDict | Learnable parameters for coefficient mixing; their structure depends on the mixing mode. |
| dropout | Dropout | Dropout layer for regularization. |
Raises:
| Type | Description |
|---|---|
| ValueError | If :attr: |
Examples:
Basic usage with pointwise mixing:
>>> import torch
>>> from spectrans.layers.mixing.wavelet import WaveletMixing
>>> mixer = WaveletMixing(hidden_dim=256, wavelet='db4', levels=3)
>>> x = torch.randn(32, 128, 256)  # (batch, seq_len, hidden)
>>> output = mixer(x)
>>> assert output.shape == x.shape
Channel mixing with identity initialization:
>>> mixer = WaveletMixing(hidden_dim=64, mixing_mode='channel', levels=2)
>>> x = torch.randn(16, 64, 64)
>>> output = mixer(x)
>>> # Weights start as identity, so the wavelet branch initially reconstructs x
>>> # and the residual sum starts near 2 * x
Cross-level mixing with attention:
>>> mixer = WaveletMixing(hidden_dim=128, mixing_mode='level', levels=4)
>>> x = torch.randn(8, 256, 128)
>>> output = mixer(x)  # Attention applied across wavelet levels
Methods:
| Name | Description |
|---|---|
| forward | Apply wavelet-based mixing following the mathematical formulation. |
| from_config | Create WaveletMixing from configuration. |
Source code in spectrans/layers/mixing/wavelet.py
Functions
forward
Apply wavelet-based mixing following the mathematical formulation.
Implements the complete wavelet mixing pipeline: decomposition → mixing → reconstruction → residual. Each hidden dimension is processed independently to maintain channel separability.
Mathematical Implementation
The forward pass implements the mathematical formulation exactly:
- Channel Extraction: \(\mathbf{x}^{(d)} = \mathbf{X}[:, :, d]\) for \(d = 0, \ldots, D-1\)
- Wavelet Decomposition: \(\text{DWT}_J(\mathbf{x}^{(d)}) \rightarrow \{\mathbf{c}_{A_J}^{(d)}, \{\mathbf{c}_{D_j}^{(d)}\}\}\)
- Learnable Mixing: Apply mode-specific transformations to coefficients
- Signal Reconstruction: \(\text{IDWT}_J(\text{mixed coefficients}) \rightarrow \hat{\mathbf{x}}^{(d)}\)
- Channel Concatenation: \(\hat{\mathbf{X}} = [\hat{\mathbf{x}}^{(0)}, \ldots, \hat{\mathbf{x}}^{(D-1)}]\)
- Residual Connection: \(\mathbf{Y} = \mathbf{X} + \text{Dropout}(\hat{\mathbf{X}})\)
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Input tensor of shape \((B, N, D)\), where \(B\) is the batch size, \(N\) the sequence length, and \(D\) the hidden dimension. | required |
Returns:
| Type | Description |
|---|---|
| Tensor | Mixed output tensor of identical shape \((B, N, D)\), with wavelet-domain mixing and the residual connection applied. |
Notes
- Dynamic coefficient length handling ensures robustness to varying sequence lengths
- Perfect reconstruction property maintained through careful padding/truncation
- Gradient flow preserved through PyTorch-native operations
Source code in spectrans/layers/mixing/wavelet.py
from_config (classmethod)
from_config(config: WaveletMixingConfig) -> WaveletMixing
Create WaveletMixing from configuration.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | WaveletMixingConfig | Typed and validated configuration. | required |
Returns:
| Type | Description |
|---|---|
| WaveletMixing | Configured instance. |
Source code in spectrans/layers/mixing/wavelet.py
WaveletMixing2D
WaveletMixing2D(channels: int, wavelet: WaveletType = 'db4', levels: int = 2, mixing_mode: str = 'subband')
Bases: Module
2D wavelet mixing layer for image-like data.
Performs mixing in 2D wavelet domain, suitable for vision transformers and other architectures processing 2D spatial data. Processes spatial information through multi-resolution wavelet subbands.
Mathematical Formulation
Given input tensor \(\mathbf{X} \in \mathbb{R}^{B \times C \times H \times W}\) where \(B\) is batch size, \(C\) is channels, \(H\) is height, and \(W\) is width:
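Step 1: Channel-wise 2D Decomposition
For each channel \(c \in \{0, 1, \ldots, C-1\}\), extract spatial data:
$$\mathbf{X}^{(c)} = \mathbf{X}[:, c, :, :] \in \mathbb{R}^{B \times H \times W}$$
Apply \(J\)-level 2D DWT decomposition:
$$\text{DWT2D}_J(\mathbf{X}^{(c)}) = \left\{\mathbf{LL}_J^{(c)}, \left\{\mathbf{LH}_j^{(c)}, \mathbf{HL}_j^{(c)}, \mathbf{HH}_j^{(c)}\right\}_{j=1}^{J}\right\}$$
where:
- \(\mathbf{LL}_J^{(c)} \in \mathbb{R}^{B \times H_J \times W_J}\) is the approximation subband (low-low)
- \(\mathbf{LH}_j^{(c)}, \mathbf{HL}_j^{(c)}, \mathbf{HH}_j^{(c)} \in \mathbb{R}^{B \times H_j \times W_j}\) are detail subbands
- \(H_j = \frac{H}{2^j}\), \(W_j = \frac{W}{2^j}\) are the spatial dimensions at level \(j\)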
Step 2: Subband Mixing
Apply mixing transformations based on mode:
Subband Mixing (`mixing_mode='subband'`):
Independent processing of each subband using convolutional networks:
Where \(f_{\cdot}\) are learnable convolutional transformations.
Cross Mixing (`mixing_mode='cross'`):
Cross-attention across all subbands:
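Step 3: 2D Reconstruction
Reconstruct the spatial signal:
$$\tilde{\mathbf{X}}^{(c)} = \text{IDWT2D}_J(\text{mixed subbands})$$
Step 4: Channel Concatenation and Residual
$$\hat{\mathbf{X}} = \left[\tilde{\mathbf{X}}^{(0)}, \ldots, \tilde{\mathbf{X}}^{(C-1)}\right], \qquad \mathbf{Y} = \mathbf{X} + \hat{\mathbf{X}}$$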
Complexity Analysis
- Time Complexity: \(O(CHW \cdot J) + O(\text{mixing operations})\)
- Space Complexity: \(O(CHW + \text{subband storage})\)

where the mixing complexity depends on the mode:
- Subband: \(O(\text{conv operations per subband})\)
- Cross: \(O(\text{attention across subbands})\)
- Attention: \(O(\text{transformer encoder})\)
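A one-level orthonormal 2D Haar transform shows how an \(H \times W\) channel splits into four \(H/2 \times W/2\) subbands and is reassembled exactly. This is a hypothetical plain-PyTorch stand-in for `DWT2D` (the layer defaults to 'db4' with two levels), and the LH/HL naming convention differs between libraries.

```python
import torch


def haar_dwt2d(x: torch.Tensor):
    """One-level orthonormal 2D Haar DWT of x (B, C, H, W) with even H and W."""
    a, b = x[..., 0::2, 0::2], x[..., 0::2, 1::2]
    c, d = x[..., 1::2, 0::2], x[..., 1::2, 1::2]
    ll = (a + b + c + d) / 2  # low-pass in both directions
    lh = (a - b + c - d) / 2  # high-pass along width
    hl = (a + b - c - d) / 2  # high-pass along height
    hh = (a - b - c + d) / 2  # high-pass in both directions
    return ll, lh, hl, hh


def haar_idwt2d(ll, lh, hl, hh):
    """Exact inverse of haar_dwt2d."""
    out = ll.new_zeros(*ll.shape[:-2], 2 * ll.shape[-2], 2 * ll.shape[-1])
    out[..., 0::2, 0::2] = (ll + lh + hl + hh) / 2
    out[..., 0::2, 1::2] = (ll - lh + hl - hh) / 2
    out[..., 1::2, 0::2] = (ll + lh - hl - hh) / 2
    out[..., 1::2, 1::2] = (ll - lh - hl + hh) / 2
    return out


x = torch.randn(2, 3, 8, 8)
ll, lh, hl, hh = haar_dwt2d(x)
print(ll.shape)  # torch.Size([2, 3, 4, 4]): H_1 = H / 2, W_1 = W / 2
print(torch.allclose(haar_idwt2d(ll, lh, hl, hh), x, atol=1e-5))  # True

# A toy "subband mix": damp the diagonal detail, then add the residual.
y = x + haar_idwt2d(ll, lh, hl, 0.5 * hh)
```

Because the transform is orthonormal, the four subbands together carry exactly the energy of the input, so any scaling applied per subband has a direct, interpretable effect on the reconstruction.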
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| channels | int | Number of input/output channels \(C\). | required |
| wavelet | str | Wavelet type determining 2D filter bank characteristics. | 'db4' |
| levels | int | Number of decomposition levels \(J\). | 2 |
| mixing_mode | str | Subband mixing strategy: 'subband' (independent), 'cross' (attention), 'attention' (transformer). | 'subband' |
Attributes:
| Name | Type | Description |
|---|---|---|
| dwt | DWT2D | 2D wavelet transform module. |
| ll_mixer | Sequential | Convolutional network for the LL subband (subband mode). |
| detail_mixers | ModuleList | Convolutional networks for detail subbands (subband mode). |
| cross_mixer | MultiheadAttention | Cross-attention module (cross mode). |
| subband_attention | TransformerEncoder | Transformer encoder for subband attention (attention mode). |
Raises:
| Type | Description |
|---|---|
| ValueError | If :attr: |
Examples:
Independent subband processing:
>>> import torch
>>> from spectrans.layers.mixing.wavelet import WaveletMixing2D
>>> mixer = WaveletMixing2D(channels=256, wavelet='db4', levels=2)
>>> x = torch.randn(32, 256, 64, 64)  # (batch, channels, height, width)
>>> output = mixer(x)
>>> assert output.shape == x.shape
Cross-subband attention:
>>> mixer = WaveletMixing2D(channels=128, mixing_mode='cross', levels=3)
>>> x = torch.randn(16, 128, 128, 128)
>>> output = mixer(x)  # Attention applied across all wavelet subbands
Methods:
| Name | Description |
|---|---|
| forward | Apply 2D wavelet-based mixing following the mathematical formulation. |
| from_config | Create WaveletMixing2D from configuration. |
Source code in spectrans/layers/mixing/wavelet.py
Functions
forward
Apply 2D wavelet-based mixing following the mathematical formulation.
Implements complete 2D wavelet mixing: spatial decomposition → subband mixing → reconstruction → residual connection. Each channel is processed independently.
Mathematical Implementation
- Channel Extraction: \(\mathbf{X}^{(c)} = \mathbf{X}[:, c, :, :]\) for each channel \(c\)
- 2D Wavelet Decomposition: \(\text{DWT2D}_J(\mathbf{X}^{(c)}) \rightarrow \text{subbands}\)
- Subband Mixing: Apply mode-specific transformations to wavelet subbands
- 2D Reconstruction: \(\text{IDWT2D}_J(\text{mixed subbands}) \rightarrow \tilde{\mathbf{X}}^{(c)}\)
- Channel Stacking: \(\hat{\mathbf{X}} = [\tilde{\mathbf{X}}^{(0)}, \ldots, \tilde{\mathbf{X}}^{(C-1)}]\)
- Residual Connection: \(\mathbf{Y} = \mathbf{X} + \hat{\mathbf{X}}\)
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Input tensor of shape \((B, C, H, W)\), where \(B\) is the batch size, \(C\) the number of channels, and \(H\), \(W\) the spatial height and width. | required |
Returns:
| Type | Description |
|---|---|
| Tensor | Mixed output tensor of identical shape \((B, C, H, W)\), with 2D wavelet-domain mixing and the residual connection applied. |
Notes
- Spatial dimensions preserved through careful reconstruction handling
- Different mixing strategies provide various inductive biases
- Subband mode: Independent processing emphasizes local features
- Cross mode: Attention enables global subband interactions
- Attention mode: Full transformer encoder for complex dependencies
Source code in spectrans/layers/mixing/wavelet.py
from_config (classmethod)
from_config(config: WaveletMixing2DConfig) -> WaveletMixing2D
Create WaveletMixing2D from configuration.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | WaveletMixing2DConfig | Typed and validated configuration. | required |
Returns:
| Type | Description |
|---|---|
| WaveletMixing2D | Configured instance. |
Source code in spectrans/layers/mixing/wavelet.py
FNOBlock
FNOBlock(hidden_dim: int, modes: int | tuple[int, ...] = 16, mlp_ratio: float = 2.0, activation: ActivationType = 'gelu', dropout: float = 0.0, norm_type: NormType = 'layernorm')
Bases: SpectralComponent
Complete FNO block with spectral convolution and feedforward network.
This block combines the FNO layer with layer normalization, residual connections, and an optional feedforward network for a complete transformer-like block.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| hidden_dim | int | Hidden dimension size. | required |
| modes | int \| tuple[int, ...] | Number of Fourier modes to retain. Default is 16. | 16 |
| mlp_ratio | float | Expansion ratio for the feedforward network. Default is 2.0. | 2.0 |
| activation | str | Activation function. Default is 'gelu'. | 'gelu' |
| dropout | float | Dropout probability. Default is 0.0. | 0.0 |
| norm_type | str | Normalization type. Default is 'layernorm'. | 'layernorm' |
Attributes:
| Name | Type | Description |
|---|---|---|
| hidden_dim | int | Hidden dimension size. |
| fno | FourierNeuralOperator | FNO layer for spectral convolution. |
| norm1 | Module | First normalization layer. |
| norm2 | Module \| None | Second normalization layer (if FFN is used). |
| ffn | Sequential \| None | Feedforward network. |
| dropout | Dropout | Dropout layer. |
Examples:
>>> import torch
>>> from spectrans.layers.operators.fno import FNOBlock
>>> block = FNOBlock(hidden_dim=64, modes=16, mlp_ratio=2.0)
>>> x = torch.randn(32, 128, 64)
>>> output = block(x)
>>> assert output.shape == x.shape
Methods:
| Name | Description |
|---|---|
| forward | Apply FNO block. |
Source code in spectrans/layers/operators/fno.py
Functions
forward
Apply FNO block.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Input tensor of shape (batch_size, sequence_length, hidden_dim). | required |
Returns:
| Type | Description |
|---|---|
| Tensor | Output tensor of the same shape as the input. |
Source code in spectrans/layers/operators/fno.py
FourierNeuralOperator
FourierNeuralOperator(hidden_dim: int, modes: int | tuple[int, ...] = 16, activation: ActivationType = 'gelu', use_spectral_conv: bool = True, use_linear: bool = True)
Bases: SpectralComponent
Fourier Neural Operator layer for learning operators in function spaces.
This layer combines spectral convolution with pointwise linear transformations to learn mappings between function spaces efficiently.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| hidden_dim | int | Hidden dimension (number of channels). | required |
| modes | int \| tuple[int, ...] | Number of Fourier modes to retain. Can be an integer for 1D or a tuple for higher dimensions. Default is 16. | 16 |
| activation | str | Activation function. Options: 'gelu', 'relu', 'tanh'. Default is 'gelu'. | 'gelu' |
| use_spectral_conv | bool | Whether to use spectral convolution. Default is True. | True |
| use_linear | bool | Whether to use pointwise linear transformation. Default is True. | True |
Attributes:
| Name | Type | Description |
|---|---|---|
| hidden_dim | int | Hidden dimension size. |
| modes | int \| tuple[int, ...] | Number of retained Fourier modes. |
| spectral_conv | SpectralConv1d \| SpectralConv2d \| None | Spectral convolution layer if enabled. |
| linear | Conv1d \| Conv2d \| None | Pointwise convolution layer if enabled. |
| activation | Module | Activation function. |
Examples:
>>> import torch
>>> from spectrans.layers.operators.fno import FourierNeuralOperator
>>> fno = FourierNeuralOperator(hidden_dim=64, modes=16)
>>> x = torch.randn(32, 128, 64)  # (batch, sequence, channels)
>>> output = fno(x)
>>> assert output.shape == x.shape
Methods:
| Name | Description |
|---|---|
| forward | Apply Fourier Neural Operator. |
Source code in spectrans/layers/operators/fno.py
Functions
forward
Apply Fourier Neural Operator.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Input tensor whose shape depends on dimensionality: 1D (batch_size, sequence_length, hidden_dim); 2D (batch_size, height, width, hidden_dim). | required |
Returns:
| Type | Description |
|---|---|
| Tensor | Output tensor of the same shape as the input. |
Source code in spectrans/layers/operators/fno.py
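In the original FNO formulation a layer computes \(\sigma(\mathcal{K}(x) + Wx)\): a spectral convolution plus a pointwise linear map, followed by the activation. The sketch assumes `FourierNeuralOperator` combines `spectral_conv` and `linear` the same way for 1D inputs of shape `(batch, sequence, hidden_dim)`. To stay short, the hypothetical `FNOLayerSketch` replaces the full spectral convolution with one complex weight per channel and mode.

```python
import torch
import torch.nn as nn


class FNOLayerSketch(nn.Module):
    """y = GELU(K(x) + W x) for x of shape (batch, seq_len, hidden_dim)."""

    def __init__(self, hidden_dim: int, modes: int):
        super().__init__()
        self.modes = modes
        # Simplified spectral kernel K: one complex scale per (channel, mode).
        self.spectral_weight = nn.Parameter(0.02 * torch.randn(hidden_dim, modes, dtype=torch.cfloat))
        # W: pointwise (kernel_size=1) convolution, i.e. a per-position linear map.
        self.pointwise = nn.Conv1d(hidden_dim, hidden_dim, kernel_size=1)
        self.act = nn.GELU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = x.transpose(1, 2)                       # (B, D, N), the layout Conv1d expects
        n = h.shape[-1]
        h_ft = torch.fft.rfft(h)                    # (B, D, n // 2 + 1)
        m = min(self.modes, h_ft.shape[-1])
        out_ft = torch.zeros_like(h_ft)
        out_ft[..., :m] = h_ft[..., :m] * self.spectral_weight[:, :m]
        spectral = torch.fft.irfft(out_ft, n=n)     # K(x), band-limited to m modes
        return self.act(spectral + self.pointwise(h)).transpose(1, 2)


layer = FNOLayerSketch(hidden_dim=16, modes=4)
x = torch.randn(2, 32, 16)
print(layer(x).shape)  # torch.Size([2, 32, 16])
```

The pointwise branch restores high-frequency and non-periodic content that the truncated spectral branch cannot represent on its own.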
SpectralConv1d
Bases: Module
1D Spectral convolution layer.
Performs convolution in the Fourier domain by element-wise multiplication with learnable complex-valued weights on truncated modes.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| in_channels | int | Number of input channels. | required |
| out_channels | int | Number of output channels. | required |
| modes | int | Number of Fourier modes to keep (frequency truncation). | required |
Attributes:
| Name | Type | Description |
|---|---|---|
| in_channels | int | Input channel count. |
| out_channels | int | Output channel count. |
| modes | int | Number of retained Fourier modes. |
| weights | Parameter | Complex-valued learnable weights of shape (in_channels, out_channels, modes). |
Examples:
>>> import torch
>>> from spectrans.layers.operators.fno import SpectralConv1d
>>> conv = SpectralConv1d(in_channels=64, out_channels=64, modes=16)
>>> x = torch.randn(32, 64, 128)  # (batch, channels, sequence)
>>> output = conv(x)
>>> assert output.shape == x.shape
Methods:
| Name | Description |
|---|---|
| forward | Apply spectral convolution. |
Source code in spectrans/layers/operators/fno.py
Functions
forward
Apply spectral convolution.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Input tensor of shape (batch_size, in_channels, sequence_length). | required |
Returns:
| Type | Description |
|---|---|
| Tensor | Output tensor of shape (batch_size, out_channels, sequence_length). |
SpectralConv2d
Bases: Module
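2D spectral convolution layer.
Performs 2D convolution in the Fourier domain for image-like data: the height and width dimensions are transformed with a 2D FFT, the retained modes (modes1 along height, modes2 along width) are multiplied by learnable complex-valued weights, and the result is transformed back with an inverse 2D FFT.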
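A sketch of a 2D spectral convolution under stated assumptions. It follows the layout of the original FNO reference code, which keeps two low-frequency blocks along the height axis (positive and negative frequencies) because `rfft2` halves only the last axis. The attribute table for this class lists a single `weights` tensor of shape (in_channels, out_channels, modes1, modes2), so spectrans may retain only one block; `SpectralConv2dSketch`, `w_pos`, and `w_neg` are hypothetical names, and the initialization scale is an assumption.

```python
import torch
import torch.nn as nn


class SpectralConv2dSketch(nn.Module):
    """Hypothetical 2D spectral convolution sketch (not the spectrans source)."""

    def __init__(self, in_channels: int, out_channels: int, modes: tuple[int, int]) -> None:
        super().__init__()
        self.out_channels = out_channels
        self.modes1, self.modes2 = modes
        scale = 1.0 / (in_channels * out_channels)  # assumed init scale
        shape = (in_channels, out_channels, self.modes1, self.modes2)
        # Two weight blocks: positive and negative frequencies along height.
        self.w_pos = nn.Parameter(scale * torch.randn(*shape, dtype=torch.cfloat))
        self.w_neg = nn.Parameter(scale * torch.randn(*shape, dtype=torch.cfloat))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, in_channels, height, width), float32.
        # Assumes 2 * modes1 <= height and modes2 <= width // 2 + 1.
        h, w = x.shape[-2:]
        x_ft = torch.fft.rfft2(x)  # (batch, in_channels, h, w // 2 + 1)
        m1, m2 = self.modes1, self.modes2
        out_ft = torch.zeros(
            x.size(0), self.out_channels, h, w // 2 + 1,
            dtype=torch.cfloat, device=x.device,
        )
        # Low positive height frequencies (top rows of the spectrum).
        out_ft[:, :, :m1, :m2] = torch.einsum(
            "bixy,ioxy->boxy", x_ft[:, :, :m1, :m2], self.w_pos
        )
        # Low negative height frequencies (bottom rows of the spectrum).
        out_ft[:, :, -m1:, :m2] = torch.einsum(
            "bixy,ioxy->boxy", x_ft[:, :, -m1:, :m2], self.w_neg
        )
        return torch.fft.irfft2(out_ft, s=(h, w))


conv2d = SpectralConv2dSketch(in_channels=3, out_channels=16, modes=(8, 8))
x = torch.randn(2, 3, 64, 64)  # (batch, channels, height, width)
assert conv2d(x).shape == (2, 16, 64, 64)
```

Only the last axis needs a single block because `rfft2` already stores just the non-negative frequencies there; the height axis keeps its full, two-sided spectrum.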
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| in_channels | int | Number of input channels. | required |
| out_channels | int | Number of output channels. | required |
| modes | tuple[int, int] | Number of Fourier modes to keep in each dimension (height, width). | required |
Attributes:
| Name | Type | Description |
|---|---|---|
| in_channels | int | Input channel count. |
| out_channels | int | Output channel count. |
| modes1 | int | Number of retained modes in the first spatial dimension. |
| modes2 | int | Number of retained modes in the second spatial dimension. |
| weights | Parameter | Complex weights of shape (in_channels, out_channels, modes1, modes2). |
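Examples:
>>> import torch
>>> from spectrans.layers.operators.fno import SpectralConv2d
>>> conv2d = SpectralConv2d(in_channels=3, out_channels=64, modes=(32, 32))
>>> x = torch.randn(8, 3, 256, 256)  # (batch, channels, height, width)
>>> output = conv2d(x)
>>> assert output.shape == (8, 64, 256, 256)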
Methods:
| Name | Description |
|---|---|
| forward | Apply 2D spectral convolution. |
Source code in spectrans/layers/operators/fno.py
Functions
forward
Apply 2D spectral convolution.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Input tensor of shape (batch_size, in_channels, height, width). | required |
Returns:
| Type | Description |
|---|---|
| Tensor | Output tensor of shape (batch_size, out_channels, height, width). |
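`FourierNeuralOperator.forward` takes 2D input in channels-last layout (batch_size, height, width, hidden_dim), while `SpectralConv2d.forward` expects channels-first (batch_size, in_channels, height, width), so a wrapper presumably permutes between the two. The helper below is a hypothetical sketch of that conversion, demonstrated with a 1x1 `nn.Conv2d` (the pointwise branch) rather than with spectrans code.

```python
import torch
import torch.nn as nn


def apply_channels_first(module: nn.Module, x: torch.Tensor) -> torch.Tensor:
    """Run a channels-first 2D module on channels-last input (hypothetical helper)."""
    # (batch, height, width, hidden) -> (batch, hidden, height, width)
    y = module(x.permute(0, 3, 1, 2))
    # (batch, hidden, height, width) -> (batch, height, width, hidden)
    return y.permute(0, 2, 3, 1)


pointwise = nn.Conv2d(16, 16, kernel_size=1)
x = torch.randn(2, 8, 12, 16)  # (batch, height, width, hidden_dim)
assert apply_channels_first(pointwise, x).shape == x.shape
```

A 1x1 convolution on the permuted tensor is the same per-position linear map as `x @ W.T + b` on the channels-last tensor, which is why either layout can be used for the pointwise path.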