LST Attention
spectrans.layers.attention.lst
Linear Spectral Transform (LST) attention mechanisms.
Implements attention mechanisms based on Linear Spectral Transforms, including the Discrete Cosine Transform (DCT), Discrete Sine Transform (DST), and Hadamard transform. All three transforms are orthogonal and have fast algorithms, so attention can be computed in \(O(n \log n)\) time in the sequence length \(n\).
LST attention replaces the standard \(\mathbf{Q}\mathbf{K}^T\) computation with element-wise multiplication in the transform domain, reducing computational complexity for long sequences.
Classes:
| Name | Description |
|---|---|
| LSTAttention | Linear Spectral Transform attention with various transform options. |
| DCTAttention | Attention using Discrete Cosine Transform. |
| HadamardAttention | Attention using fast Hadamard transform. |
Examples:
Basic LST attention with DCT:
>>> import torch
>>> from spectrans.layers.attention.lst import LSTAttention
>>> attn = LSTAttention(
... hidden_dim=512,
... num_heads=8,
... transform_type="dct"
... )
>>> x = torch.randn(32, 100, 512)
>>> output = attn(x)
>>> assert output.shape == x.shape
Multi-transform attention:
>>> from spectrans.layers.attention.lst import LSTAttention
>>> attn = LSTAttention(
... hidden_dim=512,
... num_heads=8,
... transform_type="mixed", # Uses different transforms per head
... )
>>> output = attn(x)
Notes
The LST attention computes:
Where \(T\) is an orthogonal transform (DCT, DST, Hadamard), \(\mathbf{\Lambda}\) is a learnable diagonal scaling matrix, and \(\odot\) denotes element-wise multiplication.
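As a minimal sketch of the transform-domain step (not the spectrans implementation), the snippet assumes \(T\) is the orthonormal DCT-II applied along the sequence axis, \(\mathbf{\Lambda}\) holds one weight per frequency, and the result is mapped back with \(T^{-1} = T^\top\). The helpers `dct2_matrix` and `spectral_scale` are hypothetical, and dense matrices are used for clarity; a fast DCT would bring the cost down to \(O(n \log n)\).

```python
import numpy as np


def dct2_matrix(n: int) -> np.ndarray:
    """Orthonormal DCT-II matrix T, so that T @ T.T is the identity."""
    k = np.arange(n)[:, None]  # frequency index (rows)
    i = np.arange(n)[None, :]  # token index (columns)
    t = np.sqrt(2.0 / n) * np.cos(np.pi * (2 * i + 1) * k / (2 * n))
    t[0] /= np.sqrt(2.0)  # the DC row carries an extra 1/sqrt(2)
    return t


def spectral_scale(x: np.ndarray, lam: np.ndarray) -> np.ndarray:
    """Compute T^T (Lambda ⊙ T x) along the sequence axis.

    x:   (seq_len, d) token features
    lam: (seq_len,) diagonal of Lambda, one weight per frequency (assumed layout)
    """
    t = dct2_matrix(x.shape[0])
    coeffs = t @ x  # forward transform mixes information across tokens
    coeffs = lam[:, None] * coeffs  # diagonal scaling == element-wise product
    return t.T @ coeffs  # the inverse of an orthogonal T is its transpose


x = np.random.default_rng(0).standard_normal((8, 4))
low_pass = np.zeros(8)
low_pass[0] = 1.0  # keep only the DC coefficient
y = spectral_scale(x, low_pass)  # every token becomes the per-feature mean
```

Keeping only the DC coefficient replaces every token with the mean over tokens, which shows that scaling in the transform domain mixes information across positions without computing \(\mathbf{Q}\mathbf{K}^T\).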
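Standard attention costs \(O(n^2 d)\) for sequence length \(n\) and model dimension \(d\), while LST attention reduces this to \(O(n d \log n)\). Because each transform is orthogonal (\(T^\top T = I\)), it is invertible and norm-preserving, so moving to the frequency domain loses no information.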
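The orthogonality claim can be checked numerically. This sketch builds orthonormal DCT-II, DST-II, and Sylvester-Hadamard matrices with NumPy (the helpers `dct2`, `dst2`, and `hadamard` are illustrative and unrelated to `spectrans.transforms`) and verifies \(T T^\top = I\), norm preservation, and inversion by transpose. The matrices are dense, so this check itself costs \(O(n^2)\); the \(O(n \log n)\) bound relies on fast algorithms such as FFT-based DCTs or the fast Walsh-Hadamard transform.

```python
import numpy as np


def dct2(n: int) -> np.ndarray:
    """Orthonormal DCT-II matrix."""
    k, i = np.arange(n)[:, None], np.arange(n)[None, :]
    t = np.sqrt(2.0 / n) * np.cos(np.pi * (2 * i + 1) * k / (2 * n))
    t[0] /= np.sqrt(2.0)  # the DC row carries the extra 1/sqrt(2)
    return t


def dst2(n: int) -> np.ndarray:
    """Orthonormal DST-II matrix."""
    k, i = np.arange(n)[:, None], np.arange(n)[None, :]
    t = np.sqrt(2.0 / n) * np.sin(np.pi * (2 * i + 1) * (k + 1) / (2 * n))
    t[-1] /= np.sqrt(2.0)  # the highest-frequency row carries the extra 1/sqrt(2)
    return t


def hadamard(n: int) -> np.ndarray:
    """Sylvester Hadamard matrix scaled by 1/sqrt(n); n must be a power of two."""
    if n < 1 or n & (n - 1):
        raise ValueError("n must be a power of two")
    h = np.array([[1.0]])
    while h.shape[0] < n:
        h = np.block([[h, h], [h, -h]])  # H_2n = [[H, H], [H, -H]]
    return h / np.sqrt(n)


x = np.random.default_rng(1).standard_normal(16)
for name, t in [("dct", dct2(16)), ("dst", dst2(16)), ("hadamard", hadamard(16))]:
    assert np.allclose(t @ t.T, np.eye(16)), name  # orthogonal: T T^T = I
    assert np.isclose(np.linalg.norm(t @ x), np.linalg.norm(x)), name  # norm kept
    assert np.allclose(t.T @ (t @ x), x), name  # exact inverse via transpose
```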
References
James Lee-Thorp, Joshua Ainslie, Ilya Eckstein, and Santiago Ontanon. 2022. FNet: Mixing tokens with Fourier transforms. In Proceedings of the 2022 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies (NAACL-HLT), pages 4296-4313, Seattle.
Yi Tay, Mostafa Dehghani, Samira Abnar, Yikang Shen, Dara Bahri, Philip Pham, Jinfeng Rao, Liu Yang, Sebastian Ruder, and Donald Metzler. 2021. Long range arena: A benchmark for efficient transformers. In Proceedings of the International Conference on Learning Representations (ICLR).
See Also
| Name | Description |
|---|---|
| spectrans.transforms.cosine | DCT/DST implementations. |
| spectrans.transforms.hadamard | Hadamard transform. |
| spectrans.layers.attention.spectral | Spectral attention with random Fourier features (RFF). |
Classes
LSTAttention
LSTAttention(hidden_dim: int, num_heads: int = 8, transform_type: Literal['dct', 'dst', 'hadamard', 'mixed'] = 'dct', learnable_scale: bool = True, normalize: bool = True, dropout: float = 0.0, use_bias: bool = True)
Bases: AttentionLayer
Linear Spectral Transform attention mechanism.
Implements attention using orthogonal transforms (DCT, DST, Hadamard) with learnable diagonal scaling.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| hidden_dim | int | Hidden dimension of the model. | required |
| num_heads | int | Number of attention heads. | 8 |
| transform_type | Literal['dct', 'dst', 'hadamard', 'mixed'] | Type of transform to use. "mixed" uses different transforms per head. | "dct" |
| learnable_scale | bool | Whether to use a learnable diagonal scaling matrix. | True |
| normalize | bool | Whether to normalize in the transform domain. | True |
| dropout | float | Dropout probability. | 0.0 |
| use_bias | bool | Whether to use bias in projections. | True |
Attributes:
| Name | Type | Description |
|---|---|---|
| head_dim | int | Dimension per attention head. |
| transform_type | str | Type of transform being used. |
| transforms | ModuleList | List of transforms (one per head if mixed). |
| scale | Parameter \| None | Learnable diagonal scaling if enabled, otherwise None. |
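A minimal sketch of how `head_dim` relates to `hidden_dim` and `num_heads`, assuming the usual multi-head layout in which `hidden_dim` splits evenly into `num_heads` contiguous chunks of size `head_dim = hidden_dim // num_heads`. The `split_heads` and `merge_heads` helpers are hypothetical NumPy code, not spectrans functions; in a "mixed" configuration each head's chunk would presumably be passed to its own entry of `transforms`.

```python
import numpy as np


def split_heads(x: np.ndarray, num_heads: int) -> np.ndarray:
    """(batch, seq_len, hidden_dim) -> (batch, num_heads, seq_len, head_dim)."""
    batch, seq_len, hidden_dim = x.shape
    if hidden_dim % num_heads:
        raise ValueError("hidden_dim must be divisible by num_heads")
    head_dim = hidden_dim // num_heads
    return x.reshape(batch, seq_len, num_heads, head_dim).transpose(0, 2, 1, 3)


def merge_heads(h: np.ndarray) -> np.ndarray:
    """(batch, num_heads, seq_len, head_dim) -> (batch, seq_len, hidden_dim)."""
    batch, num_heads, seq_len, head_dim = h.shape
    return h.transpose(0, 2, 1, 3).reshape(batch, seq_len, num_heads * head_dim)


x = np.random.default_rng(0).standard_normal((2, 10, 512))
heads = split_heads(x, num_heads=8)  # head_dim = 512 // 8 = 64
restored = merge_heads(heads)  # round trip recovers the original layout
```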
Methods:
| Name | Description |
|---|---|
| forward | Forward pass of LST attention. |
Functions
forward
forward(x: Tensor, mask: Tensor | None = None, return_attention: bool = False) -> Tensor | tuple[Tensor, ...]
Forward pass of LST attention.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Input tensor of shape (batch_size, seq_len, hidden_dim). | required |
| mask | Tensor \| None | Attention mask of shape (batch_size, seq_len). | None |
| return_attention | bool | Whether to also return attention weights. Weights are not computed, so None is returned in their place. | False |
Returns:
| Type | Description |
|---|---|
| Tensor or tuple[Tensor, None] | Output tensor of shape (batch_size, seq_len, hidden_dim). If return_attention=True, returns (output, None). |
DCTAttention
DCTAttention(hidden_dim: int, num_heads: int = 8, dct_type: int = 2, learnable_scale: bool = True, dropout: float = 0.0)
Bases: LSTAttention
Attention using Discrete Cosine Transform.
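Specialization of LSTAttention that uses the DCT in every head. The DCT is real-valued and has strong energy compaction: for smooth signals, most of the energy is concentrated in a few low-frequency coefficients.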
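Energy compaction can be illustrated with a smooth input: for a linear ramp, almost all of the energy lands in the first few orthonormal DCT-II coefficients, whereas the first few samples hold almost none. The helper `dct2_ortho` is a hypothetical dense-matrix DCT-II written for this sketch (presumably the type selected by the default `dct_type=2`), not `spectrans.transforms.cosine`.

```python
import numpy as np


def dct2_ortho(x: np.ndarray) -> np.ndarray:
    """Orthonormal DCT-II of a 1-D signal via a dense matrix (O(n^2), for clarity)."""
    n = x.shape[0]
    k, i = np.arange(n)[:, None], np.arange(n)[None, :]
    t = np.sqrt(2.0 / n) * np.cos(np.pi * (2 * i + 1) * k / (2 * n))
    t[0] /= np.sqrt(2.0)
    return t @ x


def energy_fraction(v: np.ndarray, m: int) -> float:
    """Share of total energy held by the first m entries of v."""
    e = v**2
    return float(e[:m].sum() / e.sum())


x = np.linspace(0.0, 1.0, 64)  # smooth, real-valued signal
coeffs = dct2_ortho(x)
print(f"first 4 DCT coefficients: {energy_fraction(coeffs, 4):.4f}")
print(f"first 4 samples:          {energy_fraction(x, 4):.4f}")
```

Because the transform is orthonormal, total energy is identical in both domains; only its distribution changes.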
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| hidden_dim | int | Hidden dimension. | required |
| num_heads | int | Number of attention heads. | 8 |
| dct_type | int | DCT type (2 is the most common). | 2 |
| learnable_scale | bool | Whether to use learnable scaling. | True |
| dropout | float | Dropout probability. | 0.0 |
HadamardAttention
HadamardAttention(hidden_dim: int, num_heads: int = 8, scale_by_sqrt: bool = True, learnable_scale: bool = True, dropout: float = 0.0)
Bases: LSTAttention
Attention using fast Hadamard transform.
Uses the Hadamard transform, whose coefficients are all \(\pm 1\) (up to normalization), so the transform needs only additions and subtractions and runs in \(O(n \log n)\) time.
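The \(O(n \log n)\) cost comes from the fast Walsh-Hadamard transform: \(\log_2 n\) butterfly stages, each doing \(n\) additions or subtractions. The sketch below (a hypothetical `fwht` helper in NumPy, not `spectrans.transforms.hadamard`) assumes a power-of-two length and natural (Sylvester) ordering. `normalize=True` divides by \(\sqrt{n}\), which is presumably what `scale_by_sqrt` controls, since the unnormalized matrix satisfies \(H H^\top = n I\).

```python
import numpy as np


def fwht(x: np.ndarray, normalize: bool = True) -> np.ndarray:
    """Fast Walsh-Hadamard transform along axis 0, in Sylvester (natural) order."""
    n = x.shape[0]
    if n == 0 or n & (n - 1):
        raise ValueError("length must be a power of two")
    y = np.array(x, dtype=float, copy=True)
    h = 1
    while h < n:  # log2(n) stages
        for start in range(0, n, 2 * h):
            a = y[start:start + h].copy()
            b = y[start + h:start + 2 * h].copy()
            y[start:start + h] = a + b  # butterfly: sums only
            y[start + h:start + 2 * h] = a - b  # and differences
        h *= 2
    return y / np.sqrt(n) if normalize else y


x = np.random.default_rng(0).standard_normal(8)
y = fwht(x)  # orthonormal Hadamard coefficients
```

With normalization the transform is its own inverse, because the Sylvester matrix is symmetric: applying `fwht` twice returns the input.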
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| hidden_dim | int | Hidden dimension. | required |
| num_heads | int | Number of attention heads. | 8 |
| scale_by_sqrt | bool | Whether to divide by \(\sqrt{n}\) so that the transform is orthonormal. | True |
| learnable_scale | bool | Whether to use learnable diagonal scaling. | True |
| dropout | float | Dropout probability. | 0.0 |
MixedSpectralAttention
MixedSpectralAttention(hidden_dim: int, num_heads: int = 8, use_fft: bool = True, use_dct: bool = True, use_hadamard: bool = True, dropout: float = 0.0)
Bases: AttentionLayer
Mixed spectral attention using multiple transform types.
Assigns different spectral transforms (FFT, DCT, Hadamard) to different groups of heads, so that the heads capture diverse frequency representations.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| hidden_dim | int | Hidden dimension. | required |
| num_heads | int | Number of attention heads. When all three transform types are enabled, a multiple of 3 splits the heads evenly. | 8 |
| use_fft | bool | Whether to include FFT heads. | True |
| use_dct | bool | Whether to include DCT heads. | True |
| use_hadamard | bool | Whether to include Hadamard heads. | True |
| dropout | float | Dropout probability. | 0.0 |
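When `num_heads` is not a multiple of the number of enabled transforms, some transforms must receive more heads than others. How MixedSpectralAttention distributes the remainder is not specified here; the hypothetical `allocate_heads` helper below assumes the leftover heads go to the earliest enabled transforms, in the order FFT, DCT, Hadamard.

```python
def allocate_heads(
    num_heads: int,
    use_fft: bool = True,
    use_dct: bool = True,
    use_hadamard: bool = True,
) -> dict[str, int]:
    """Split num_heads across the enabled transforms (hypothetical policy)."""
    kinds = [
        name
        for name, enabled in (("fft", use_fft), ("dct", use_dct), ("hadamard", use_hadamard))
        if enabled
    ]
    if not kinds:
        raise ValueError("enable at least one transform")
    base, extra = divmod(num_heads, len(kinds))
    # Assumption: the first `extra` transforms each get one additional head.
    return {kind: base + (1 if i < extra else 0) for i, kind in enumerate(kinds)}


even = allocate_heads(9)  # three transforms, three heads each
uneven = allocate_heads(8)  # one transform ends up with fewer heads
```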
Methods:
| Name | Description |
|---|---|
| forward | Forward pass of mixed spectral attention. |
Functions
forward
forward(x: Tensor, _mask: Tensor | None = None, return_attention: bool = False) -> Tensor | tuple[Tensor, ...]
Forward pass of mixed spectral attention.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Input of shape (batch_size, seq_len, hidden_dim). | required |
| _mask | Tensor \| None | Attention mask. Currently ignored: masking is not implemented for spectral attention. | None |
| return_attention | bool | Whether to also return attention weights. Weights are not computed, so None is returned in their place. | False |
Returns:
| Type | Description |
|---|---|
| Tensor or tuple[Tensor, None] | The output tensor, or (output, None) when return_attention=True. |