LST Models
spectrans.models.lst
Linear Spectral Transform (LST) models using efficient transforms.
This module implements transformer models that use linear spectral transforms (DCT, DST, Hadamard) for sequence mixing instead of attention mechanisms. These models achieve \(O(n \log n)\) complexity in the sequence length \(n\) through FFT-style fast algorithms, providing an efficient alternative to quadratic attention.
The LST mechanism applies learned transformations in the spectral domain, enabling global token interactions while maintaining computational efficiency.
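To give a feel for the complexity claim, the following sketch compares rough operation counts for quadratic mixing and for an \(O(n \log n)\) transform. The function name `mixing_cost` is hypothetical, and constant factors and the hidden dimension are ignored, so the numbers are only indicative.

```python
import math


def mixing_cost(n: int) -> tuple[int, int]:
    """Return rough operation counts (quadratic, fast transform) for sequence length n."""
    quadratic = n * n                # one interaction per token pair
    fast = int(n * math.log2(n))     # butterfly-style fast transform
    return quadratic, fast


for n in (128, 1024, 8192):
    quadratic, fast = mixing_cost(n)
    print(f"n={n:5d}  n^2={quadratic:>10,d}  n*log2(n)={fast:>8,d}  ratio={quadratic / fast:6.1f}x")
```

The gap widens with sequence length, which is why the benefit is most visible for long inputs.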
Classes:
| Name | Description |
|---|---|
| `LSTTransformer` | Complete transformer model using linear spectral transforms. |
| `LSTEncoder` | Encoder-only model for representation learning. |
| `LSTDecoder` | Decoder model with causal masking support. |
Examples:
Basic LST transformer:
>>> import torch
>>> from spectrans.models.lst import LSTTransformer
>>> model = LSTTransformer(
... hidden_dim=512,
... num_layers=6,
... transform_type="dct",
... max_sequence_length=1024
... )
>>> x = torch.randn(32, 100, 512) # (batch, seq_len, dim)
>>> output = model(inputs_embeds=x)
>>> assert output.shape == x.shape
Using with token inputs and classification:
>>> model = LSTTransformer(
... vocab_size=10000,
... hidden_dim=512,
... num_layers=6,
... transform_type="hadamard",
... num_classes=10,
... max_sequence_length=512
... )
>>> input_ids = torch.randint(0, 10000, (32, 100))
>>> logits = model(input_ids)
>>> assert logits.shape == (32, 10)
Causal decoder model:
>>> from spectrans.models.lst import LSTDecoder
>>> decoder = LSTDecoder(
... vocab_size=10000,
... hidden_dim=512,
... num_layers=12,
... transform_type="dst",
... causal=True,
... max_sequence_length=2048
... )
Notes
Mathematical Foundation:
The LST mechanism replaces attention with spectral domain operations:
$$ \text{LST}(\mathbf{X}) = \mathcal{T}^{-1}\left(\mathbf{W} \odot \mathcal{T}(\mathbf{X})\right) $$
Where:
- \(\mathbf{X}\) is the input sequence
- \(\mathcal{T}\) is the forward spectral transform (DCT/DST/Hadamard)
- \(\mathcal{T}^{-1}\) is the inverse transform
- \(\mathbf{W}\) is a learned spectral weighting matrix
- \(\odot\) denotes element-wise multiplication
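The mechanism described above (forward transform, element-wise weighting by spectral weights, inverse transform) can be sketched in plain Python. This is an illustration only, not the library's implementation: `dct_matrix` and `lst_mix` are hypothetical names, the transform is a dense orthonormal DCT-II (so it costs \(O(n^2)\) rather than \(O(n \log n)\)), and the weights are fixed by hand for a single scalar channel instead of being learned.

```python
import math


def dct_matrix(n: int) -> list[list[float]]:
    """Orthonormal DCT-II matrix C; its inverse is simply its transpose."""
    rows = []
    for k in range(n):
        scale = math.sqrt((1.0 if k == 0 else 2.0) / n)
        rows.append([scale * math.cos(math.pi * k * (2 * j + 1) / (2 * n)) for j in range(n)])
    return rows


def lst_mix(x: list[float], weights: list[float]) -> list[float]:
    """Apply T^{-1}(w * T(x)) to one sequence of scalars."""
    n = len(x)
    c = dct_matrix(n)
    spectrum = [sum(c[k][j] * x[j] for j in range(n)) for k in range(n)]      # T(x)
    weighted = [w * s for w, s in zip(weights, spectrum)]                      # w * T(x), element-wise
    return [sum(c[k][j] * weighted[k] for k in range(n)) for j in range(n)]    # T^{-1}(...)


x = [1.0, 2.0, 3.0, 4.0]
identity = lst_mix(x, [1.0, 1.0, 1.0, 1.0])   # unit weights leave the sequence unchanged
dc_only = lst_mix(x, [1.0, 0.0, 0.0, 0.0])    # keeping only k = 0 replaces every token by the mean
```

With unit weights the round trip reproduces the input, and keeping only the zero-frequency coefficient yields the sequence mean at every position, which shows how a single spectral weight affects all tokens at once.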
The transforms have efficient \(O(n \log n)\) implementations:
- DCT (Discrete Cosine Transform):
$$ X_k = \sum_{n=0}^{N-1} x_n \cos\left(\frac{\pi k(2n+1)}{2N}\right) $$
- DST (Discrete Sine Transform):
$$ X_k = \sum_{n=0}^{N-1} x_n \sin\left(\frac{\pi (k+1)(n+1)}{N+1}\right) $$
- Hadamard Transform:
$$ H_N = H_2 \otimes H_{\frac{N}{2}} = \begin{bmatrix} H_{\frac{N}{2}} & H_{\frac{N}{2}} \\ H_{\frac{N}{2}} & -H_{\frac{N}{2}} \end{bmatrix} $$
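The recursive block structure of \(H_N\) is what makes an \(O(n \log n)\) algorithm possible. Below is a minimal sketch of the standard fast Walsh-Hadamard transform (unnormalized, natural ordering), checked against a dense matrix built from the block definition. The function names `fwht` and `hadamard` are illustrative and are not taken from spectrans.

```python
def fwht(x: list[float]) -> list[float]:
    """Unnormalized fast Walsh-Hadamard transform; len(x) must be a power of two."""
    n = len(x)
    if n == 0 or n & (n - 1):
        raise ValueError("length must be a power of two")
    out = list(x)
    h = 1
    while h < n:                      # log2(n) stages
        for start in range(0, n, 2 * h):
            for i in range(start, start + h):
                a, b = out[i], out[i + h]
                out[i], out[i + h] = a + b, a - b   # 2-point butterfly (H_2)
        h *= 2
    return out


def hadamard(n: int) -> list[list[int]]:
    """Dense H_n from the block definition [[H, H], [H, -H]], for comparison only."""
    if n == 1:
        return [[1]]
    half = hadamard(n // 2)
    top = [row + row for row in half]
    bottom = [row + [-v for v in row] for row in half]
    return top + bottom


x = [1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0]
fast = fwht(x)
dense = [sum(h * v for h, v in zip(row, x)) for row in hadamard(len(x))]
```

Because \(H_N H_N = N I\), applying `fwht` twice returns the input scaled by \(N\), so the inverse is the same routine followed by division by \(N\).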
The spectral weights enable frequency-selective filtering, allowing the model to learn which frequency components are important for the task.
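As a sketch of frequency-selective filtering, the example below uses the DST formula given above, a hand-set low-pass weight vector, and a signal made of one low and one high sine mode. The helper names `dst` and `idst` are hypothetical, the transform is evaluated directly in \(O(n^2)\), and in the model the weights would be learned rather than fixed.

```python
import math


def dst(x: list[float]) -> list[float]:
    """DST-I: X_k = sum_n x_n * sin(pi * (k + 1) * (n + 1) / (N + 1))."""
    n = len(x)
    return [
        sum(x[j] * math.sin(math.pi * (k + 1) * (j + 1) / (n + 1)) for j in range(n))
        for k in range(n)
    ]


def idst(spectrum: list[float]) -> list[float]:
    """DST-I is its own inverse up to the factor 2 / (N + 1)."""
    n = len(spectrum)
    return [2.0 / (n + 1) * v for v in dst(spectrum)]


n = 15
low = [math.sin(math.pi * 1 * (j + 1) / (n + 1)) for j in range(n)]          # mode k = 0
high = [0.5 * math.sin(math.pi * 12 * (j + 1) / (n + 1)) for j in range(n)]  # mode k = 11
signal = [a + b for a, b in zip(low, high)]

# Hand-set low-pass weights: keep the four lowest frequencies, drop the rest.
weights = [1.0 if k < 4 else 0.0 for k in range(n)]
filtered = idst([w * s for w, s in zip(weights, dst(signal))])
```

Here `filtered` matches the low-frequency component up to floating-point error, because the high mode falls entirely on a coefficient whose weight is zero.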
References
James Lee-Thorp, Joshua Ainslie, Ilya Eckstein, and Santiago Ontanon. 2022. FNet: Mixing tokens with Fourier transforms. In Proceedings of the 2022 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies (NAACL-HLT), pages 4296-4313, Seattle.
Yi Tay, Mostafa Dehghani, Samira Abnar, Yikang Shen, Dara Bahri, Philip Pham, Jinfeng Rao, Liu Yang, Sebastian Ruder, and Donald Metzler. 2021. Long range arena: A benchmark for efficient transformers. In International Conference on Learning Representations (ICLR 2021).
Nasir Ahmed, T. Natarajan, and Kamisetty R. Rao. 1974. Discrete cosine transform. IEEE Transactions on Computers, C-23(1):90-93.
See Also
- spectrans.layers.mixing : Spectral mixing layer implementations.
- spectrans.transforms.spectral : Core spectral transform implementations.
- spectrans.models.spectral_attention : Spectral attention models for comparison.
Classes
LSTTransformer
LSTTransformer(vocab_size: int | None = None, hidden_dim: int = 512, num_layers: int = 6, max_sequence_length: int = 1024, transform_type: TransformLSTType = 'dct', use_conv_bias: bool = True, num_classes: int | None = None, ffn_hidden_dim: int | None = None, dropout: float = 0.0, use_positional_encoding: bool = True, positional_encoding_type: PositionalEncodingType = 'sinusoidal', gradient_checkpointing: bool = False)
Bases: BaseModel
Linear Spectral Transform transformer model.
This model uses linear spectral transforms (DCT/DST/Hadamard) for sequence mixing, achieving \(O(n \log n)\) complexity through fast transform algorithms. The model applies learned transformations in the spectral domain for efficient global token interactions.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `vocab_size` | `int \| None` | Size of the vocabulary for token embeddings. If None, expects pre-embedded inputs. | `None` |
| `hidden_dim` | `int` | Hidden dimension size for the model. | `512` |
| `num_layers` | `int` | Number of transformer blocks. | `6` |
| `max_sequence_length` | `int` | Maximum sequence length the model can process. | `1024` |
| `transform_type` | `TransformLSTType` | Type of spectral transform to use. | `"dct"` |
| `use_conv_bias` | `bool` | Whether to use bias in spectral convolution. | `True` |
| `num_classes` | `int \| None` | Number of output classes for classification. | `None` |
| `ffn_hidden_dim` | `int \| None` | Hidden dimension of the feedforward network. Default is 4 * hidden_dim. | `None` |
| `dropout` | `float` | Dropout probability. | `0.0` |
| `use_positional_encoding` | `bool` | Whether to use positional encoding. | `True` |
| `positional_encoding_type` | `PositionalEncodingType` | Type of positional encoding ("sinusoidal", "learned", "rotary", "alibi", or "none"). | `"sinusoidal"` |
| `gradient_checkpointing` | `bool` | Whether to use gradient checkpointing to save memory. | `False` |
Attributes:
| Name | Type | Description |
|---|---|---|
| `blocks` | `ModuleList` | Stack of LST transformer blocks. |
Examples:
>>> import torch
>>> from spectrans.models.lst import LSTTransformer
>>> model = LSTTransformer(
... hidden_dim=512,
... num_layers=6,
... transform_type="dct",
... max_sequence_length=1024
... )
>>> x = torch.randn(32, 100, 512)
>>> output = model(inputs_embeds=x)
>>> assert output.shape == x.shape
Methods:
| Name | Description |
|---|---|
| `build_blocks` | Build transformer blocks with LST layers. |
| `from_config` | Create model from configuration. |
Source code in spectrans/models/lst.py
Functions
build_blocks
Build transformer blocks with LST layers.
Returns:
| Type | Description |
|---|---|
| `ModuleList` | List of LST transformer blocks. |
from_config classmethod
from_config(config: LSTModelConfig) -> LSTTransformer
Create model from configuration.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `config` | `LSTModelConfig` | Model configuration object. | required |
Returns:
| Type | Description |
|---|---|
| `LSTTransformer` | Configured model instance. |
LSTEncoder
LSTEncoder(vocab_size: int | None = None, hidden_dim: int = 512, num_layers: int = 6, max_sequence_length: int = 1024, transform_type: TransformLSTType = 'dct', use_conv_bias: bool = True, ffn_hidden_dim: int | None = None, dropout: float = 0.0, use_positional_encoding: bool = True, positional_encoding_type: PositionalEncodingType = 'sinusoidal')
Bases: BaseModel
Encoder-only LST model for representation learning.
This model uses linear spectral transforms without a classification head, suitable for generating embeddings or as a component in larger architectures.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `vocab_size` | `int \| None` | Size of the vocabulary for token embeddings. | `None` |
| `hidden_dim` | `int` | Hidden dimension size. | `512` |
| `num_layers` | `int` | Number of transformer blocks. | `6` |
| `max_sequence_length` | `int` | Maximum sequence length. | `1024` |
| `transform_type` | `TransformLSTType` | Type of spectral transform. | `"dct"` |
| `use_conv_bias` | `bool` | Whether to use bias in spectral convolution. | `True` |
| `ffn_hidden_dim` | `int \| None` | Hidden dimension of the feedforward network. | `None` |
| `dropout` | `float` | Dropout probability. | `0.0` |
| `use_positional_encoding` | `bool` | Whether to use positional encoding. | `True` |
| `positional_encoding_type` | `PositionalEncodingType` | Type of positional encoding. | `"sinusoidal"` |
Methods:
| Name | Description |
|---|---|
| `build_blocks` | Build encoder blocks with LST layers. |
Functions
build_blocks
Build encoder blocks with LST layers.
Returns:
| Type | Description |
|---|---|
| `ModuleList` | List of LST encoder blocks. |
LSTDecoder
LSTDecoder(vocab_size: int, hidden_dim: int = 512, num_layers: int = 12, max_sequence_length: int = 2048, transform_type: TransformLSTType = 'dst', causal: bool = True, use_conv_bias: bool = True, ffn_hidden_dim: int | None = None, dropout: float = 0.0, use_positional_encoding: bool = True, positional_encoding_type: PositionalEncodingType = 'sinusoidal', gradient_checkpointing: bool = False)
Bases: BaseModel
Decoder LST model with optional causal masking.
This model uses linear spectral transforms with support for causal masking, suitable for autoregressive generation tasks.
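Causality means that the output at position \(s\) must not depend on any input at a later position. The sketch below shows one generic way to probe that property for a sequence mixer by perturbing each input in turn. It does not describe how LSTDecoder implements masking: `is_causal`, `running_mean`, and `global_mean` are hypothetical stand-ins, and the probe only checks behaviour around the given input.

```python
def is_causal(mixer, x: list[float], eps: float = 1.0, tol: float = 1e-9) -> bool:
    """Return True if changing x[t] never changes mixer(x)[s] for any s < t."""
    base = mixer(x)
    for t in range(len(x)):
        bumped = list(x)
        bumped[t] += eps                      # perturb a single position
        out = mixer(bumped)
        if any(abs(out[s] - base[s]) > tol for s in range(t)):
            return False                      # an earlier output saw a later input
    return True


def running_mean(x: list[float]) -> list[float]:
    """Toy causal mixer: position i averages inputs 0..i only."""
    out, total = [], 0.0
    for count, value in enumerate(x, start=1):
        total += value
        out.append(total / count)
    return out


def global_mean(x: list[float]) -> list[float]:
    """Toy non-causal mixer: every position sees the whole sequence."""
    mean = sum(x) / len(x)
    return [mean] * len(x)


x = [0.5, -1.0, 2.0, 3.0]
causal_ok = is_causal(running_mean, x)
global_ok = is_causal(global_mean, x)
```

A global transform such as an unmasked DCT behaves like `global_mean` in this test, which is why a decoder needs an explicit causal variant of the mixing layer.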
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `vocab_size` | `int` | Size of the vocabulary. | required |
| `hidden_dim` | `int` | Hidden dimension size. | `512` |
| `num_layers` | `int` | Number of transformer blocks. | `12` |
| `max_sequence_length` | `int` | Maximum sequence length. | `2048` |
| `transform_type` | `TransformLSTType` | Type of spectral transform (DST is preferred for causal decoding). | `"dst"` |
| `causal` | `bool` | Whether to use causal masking. | `True` |
| `use_conv_bias` | `bool` | Whether to use bias in spectral convolution. | `True` |
| `ffn_hidden_dim` | `int \| None` | Hidden dimension of the feedforward network. | `None` |
| `dropout` | `float` | Dropout probability. | `0.0` |
| `use_positional_encoding` | `bool` | Whether to use positional encoding. | `True` |
| `positional_encoding_type` | `PositionalEncodingType` | Type of positional encoding. | `"sinusoidal"` |
| `gradient_checkpointing` | `bool` | Whether to use gradient checkpointing. | `False` |
Examples:
>>> import torch
>>> from spectrans.models.lst import LSTDecoder
>>> decoder = LSTDecoder(
... vocab_size=10000,
... hidden_dim=512,
... num_layers=12,
... transform_type="dst",
... causal=True,
... max_sequence_length=2048
... )
>>> input_ids = torch.randint(0, 10000, (32, 100))
>>> logits = decoder(input_ids)
>>> assert logits.shape == (32, 100, 10000)
Methods:
| Name | Description |
|---|---|
| `build_blocks` | Build decoder blocks with causal LST layers. |
| `forward` | Forward pass through the decoder. |
Functions
build_blocks
Build decoder blocks with causal LST layers.
Returns:
| Type | Description |
|---|---|
| `ModuleList` | List of causal LST decoder blocks. |
forward
forward(input_ids: Tensor | None = None, inputs_embeds: Tensor | None = None, attention_mask: Tensor | None = None) -> Tensor
Forward pass through the decoder.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `input_ids` | `Tensor \| None` | Input token IDs of shape (batch_size, sequence_length). | `None` |
| `inputs_embeds` | `Tensor \| None` | Pre-embedded inputs of shape (batch_size, sequence_length, hidden_dim). | `None` |
Returns:
| Type | Description |
|---|---|
| `Tensor` | Language modeling logits of shape (batch_size, sequence_length, vocab_size). |
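For autoregressive use, only the logits at the last position of each sequence determine the next token. The following sketch shows greedy selection on a plain nested list with the same (batch_size, sequence_length, vocab_size) layout. The helper `greedy_next_tokens` is hypothetical and the logits are made-up numbers, not model output.

```python
def greedy_next_tokens(logits: list[list[list[float]]]) -> list[int]:
    """Pick the highest-scoring token id at the last position of each sequence."""
    next_ids = []
    for sequence in logits:                   # iterate over the batch
        last = sequence[-1]                   # scores over the vocabulary at the final position
        next_ids.append(max(range(len(last)), key=last.__getitem__))
    return next_ids


# Toy logits with batch_size=2, sequence_length=2, vocab_size=3.
logits = [
    [[0.1, 2.0, -1.0], [0.3, 0.2, 1.5]],
    [[1.2, 0.0, 0.4], [2.2, -0.5, 0.1]],
]
next_ids = greedy_next_tokens(logits)
```

On a real tensor the equivalent operation would be `logits[:, -1, :].argmax(dim=-1)`; the selected ids are then appended to `input_ids` before the next forward pass.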