Wavelet Transformer
spectrans.models.wavenet_transformer
Wavelet transformer models using discrete wavelet transforms.
This module implements transformer architectures that replace standard attention mechanisms with discrete wavelet transforms (DWT) for sequence mixing. The DWT provides multi-resolution analysis, decomposing sequences into approximation and detail coefficients at different scales while maintaining perfect reconstruction.
Classes:
| Name | Description |
|---|---|
| WaveletTransformer | Full wavelet transformer with DWT-based sequence mixing. |
| WaveletEncoder | Encoder-only variant for representation learning tasks. |
| WaveletDecoder | Decoder variant for sequence generation with causal processing. |
Examples:
Basic wavelet transformer for classification:
>>> import torch
>>> from spectrans.models.wavenet_transformer import WaveletTransformer
>>> model = WaveletTransformer(
...     vocab_size=30000,
...     hidden_dim=768,
...     num_layers=12,
...     wavelet='db4',
...     levels=3,
...     max_sequence_length=512,
...     num_classes=10
... )
>>> input_ids = torch.randint(0, 30000, (32, 512))
>>> logits = model(input_ids)
>>> assert logits.shape == (32, 10)
Using different wavelet families:
>>> model_db = WaveletTransformer(
...     hidden_dim=512,
...     wavelet='db8',
...     levels=4
... )
>>> model_sym = WaveletTransformer(
...     hidden_dim=512,
...     wavelet='sym6',
...     levels=3
... )
Encoder for feature extraction:
>>> from spectrans.models.wavenet_transformer import WaveletEncoder
>>> encoder = WaveletEncoder(
...     hidden_dim=768,
...     num_layers=6,
...     wavelet='coif3',
...     levels=2
... )
>>> embeddings = torch.randn(32, 100, 768)
>>> features = encoder(inputs_embeds=embeddings)
Notes
Mathematical Foundation:
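The discrete wavelet transform decomposes a signal \(\mathbf{x} \in \mathbb{R}^n\) into a multi-scale representation. For \(J\) decomposition levels, the DWT produces:
\[\text{DWT}_J(\mathbf{x}) = \{\mathbf{c}_{A_J}, \mathbf{c}_{D_J}, \mathbf{c}_{D_{J-1}}, \ldots, \mathbf{c}_{D_1}\}\]
where \(\mathbf{c}_{A_J} \in \mathbb{R}^{\frac{n}{2^J}}\) are approximation coefficients at the coarsest level and \(\mathbf{c}_{D_j} \in \mathbb{R}^{\frac{n}{2^j}}\) are detail coefficients at level \(j\).
The decomposition employs convolution with filter banks:
\[\mathbf{c}_{A_{j+1}}[k] = \sum_{m} h[m - 2k]\, \mathbf{c}_{A_j}[m], \qquad \mathbf{c}_{D_{j+1}}[k] = \sum_{m} g[m - 2k]\, \mathbf{c}_{A_j}[m]\]
where \(h\) and \(g\) are the low-pass and high-pass analysis filters and \(\mathbf{c}_{A_0} = \mathbf{x}\). Perfect reconstruction is guaranteed by the synthesis filters satisfying:
\[\mathbf{x} = \sum_{k} \mathbf{c}_{A_J}[k]\, \phi_{J,k} + \sum_{j=1}^{J} \sum_{k} \mathbf{c}_{D_j}[k]\, \psi_{j,k}\]
where \(\phi_{J,k}\) and \(\psi_{j,k}\) are scaling and wavelet functions.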
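The decomposition and its inverse can be illustrated with the Haar wavelet, the shortest orthogonal filter pair. This is a minimal pure-Python sketch, not the library's implementation: the function names `haar_dwt` and `haar_idwt` are made up for illustration, and the signal length is assumed to be divisible by \(2^J\).

```python
import math

def haar_dwt(x, levels):
    """Return [cA_J, cD_J, ..., cD_1] for a signal whose length is divisible by 2**levels."""
    s = 1 / math.sqrt(2)
    approx = list(x)
    details = []
    for _ in range(levels):
        evens, odds = approx[0::2], approx[1::2]
        # High-pass branch followed by downsampling by 2.
        details.append([(a - b) * s for a, b in zip(evens, odds)])
        # Low-pass branch followed by downsampling by 2.
        approx = [(a + b) * s for a, b in zip(evens, odds)]
    # Coarsest approximation first, then details from coarse to fine.
    return [approx] + details[::-1]

def haar_idwt(coeffs):
    """Invert haar_dwt by upsampling and recombining, coarsest level first."""
    s = 1 / math.sqrt(2)
    approx = list(coeffs[0])
    for detail in coeffs[1:]:
        merged = []
        for a, d in zip(approx, detail):
            merged.extend([(a + d) * s, (a - d) * s])
        approx = merged
    return approx

x = [4.0, 6.0, 10.0, 12.0, 8.0, 6.0, 5.0, 5.0]
coeffs = haar_dwt(x, levels=2)
sizes = [len(c) for c in coeffs]   # lengths n/4, n/4, n/2
x_rec = haar_idwt(coeffs)
max_err = max(abs(a - b) for a, b in zip(x, x_rec))
```

With \(n = 8\) and \(J = 2\), the coefficient lengths are \(n/2^J\) for the approximation and \(n/2^j\) for each detail level, and the reconstruction matches the input up to floating-point rounding.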
Transformer Block Structure:
Each wavelet transformer block applies DWT mixing with residual connections. The wavelet mixing operation processes each channel of the hidden representation independently through the DWT/IDWT pipeline.
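A minimal sketch of such a block follows. It assumes a pre-norm layout (normalize, mix, add) followed by a position-wise feed-forward sublayer. The norm placement and the names `layer_norm`, `block`, `mix`, and `ffn` are assumptions for illustration and are not taken from the library.

```python
import math

def layer_norm(row, eps=1e-12):
    """Normalize one token vector to zero mean and unit variance (no learned scale)."""
    mean = sum(row) / len(row)
    var = sum((v - mean) ** 2 for v in row) / len(row)
    return [(v - mean) / math.sqrt(var + eps) for v in row]

def block(x, mix, ffn):
    """x is a list of n token rows with d channels each.

    mix: maps the whole (n x d) list to a new (n x d) list (sequence mixing).
    ffn: maps a single d-vector to a d-vector (position-wise).
    """
    normed = [layer_norm(row) for row in x]
    mixed = mix(normed)
    # First residual connection around the mixing sublayer.
    x = [[a + b for a, b in zip(r, m)] for r, m in zip(x, mixed)]
    normed = [layer_norm(row) for row in x]
    fed = [ffn(row) for row in normed]
    # Second residual connection around the feed-forward sublayer.
    return [[a + b for a, b in zip(r, f)] for r, f in zip(x, fed)]

# With sublayers that output zeros, the residual paths pass the input through unchanged.
zero_mix = lambda rows: [[0.0] * len(r) for r in rows]
zero_ffn = lambda row: [0.0] * len(row)
x = [[1.0, 2.0], [3.0, 5.0], [0.5, -1.0]]
out = block(x, zero_mix, zero_ffn)
```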
Complexity Analysis:
For a sequence of length \(n\) with hidden dimension \(d\) and \(L\) layers:
- Time complexity: \(O(L \cdot n \cdot d \cdot J)\) where \(J\) is the number of decomposition levels
- Space complexity: \(O(L \cdot n \cdot d)\)
- Single DWT operation: \(O(n)\) per channel due to fast convolution algorithms
Because each DWT costs \(O(n)\) per channel, wavelet mixing scales linearly with sequence length, whereas self-attention scales as \(O(n^2)\). The efficiency advantage therefore grows with sequence length.
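To make the comparison concrete, the sketch below counts multiplications for one forward DWT per channel against the attention score computation. It assumes a direct (non-FFT) two-channel filter bank with 8 taps, which is the filter length of 'db4', and counts only the \(QK^\top\) product for attention. The function names are illustrative.

```python
def dwt_mults(n, filter_len, levels):
    """Multiplications for one channel: two filters per level on a halving signal."""
    total, length = 0, n
    for _ in range(levels):
        # Each level produces length // 2 low-pass and length // 2 high-pass outputs.
        total += 2 * filter_len * (length // 2)
        length //= 2
    return total

def attention_score_mults(n, d):
    """Multiplications for the n x n score matrix with d-dimensional queries and keys."""
    return n * n * d

n, d = 4096, 768
per_channel = dwt_mults(n, filter_len=8, levels=3)
wavelet_total = per_channel * d
attention_total = attention_score_mults(n, d)
ratio = attention_total / wavelet_total
```

Since the signal halves at every level, the per-channel count stays below `2 * filter_len * n` however many levels are used.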
References
Stéphane Mallat. 1999. A Wavelet Tour of Signal Processing, 2nd edition. Academic Press, San Diego.
Ingrid Daubechies. 1992. Ten Lectures on Wavelets. CBMS-NSF Regional Conference Series in Applied Mathematics, Vol. 61. SIAM, Philadelphia.
Martin Vetterli and Jelena Kovačević. 1995. Wavelets and Subband Coding. Prentice Hall, Englewood Cliffs.
See Also
spectrans.layers.mixing.wavelet : Wavelet mixing layer implementation.
spectrans.transforms.wavelet : DWT transform implementations.
Classes
WaveletTransformer
WaveletTransformer(vocab_size: int | None = None, hidden_dim: int = 768, num_layers: int = 12, max_sequence_length: int = 512, wavelet: WaveletType = 'db4', levels: int = 3, mixing_mode: str = 'pointwise', num_classes: int | None = None, use_positional_encoding: bool = True, positional_encoding_type: PositionalEncodingType = 'sinusoidal', dropout: float = 0.1, ffn_hidden_dim: int | None = None, norm_eps: float = 1e-12, output_type: OutputHeadType = 'classification', gradient_checkpointing: bool = False)
Bases: BaseModel
Wavelet transformer with DWT-based sequence mixing.
This model replaces attention mechanisms with discrete wavelet transforms, providing multi-resolution analysis of sequences with \(O(n)\) complexity per channel. The DWT decomposes input sequences into approximation and detail coefficients at multiple scales, representing both local transients and global structure.
The wavelet mixing operation applies the DWT along the sequence dimension for each channel independently, processes the coefficients through learnable transformations, and reconstructs the sequence via the inverse DWT (IDWT). Perfect reconstruction is maintained when no coefficient modification occurs.
For input \(\mathbf{X} \in \mathbb{R}^{n \times d}\), each channel undergoes:
\[\mathbf{c} = \text{DWT}_J(\mathbf{X}_{:,i}) \quad \text{for } i \in [1,d]\]
\[\tilde{\mathbf{c}} = f_{\theta}(\mathbf{c})\]
\[\mathbf{Y}_{:,i} = \text{IDWT}_J(\tilde{\mathbf{c}})\]
where \(f_{\theta}\) represents learnable coefficient transformations and \(J\) is the number of decomposition levels.
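The three steps can be sketched in pure Python with a single-level Haar transform and an elementwise scale standing in for \(f_{\theta}\). This is an illustration under those assumptions, not the library's layer. The names `dwt1`, `idwt1`, `wavelet_mix`, `approx_scale`, and `detail_scale` are hypothetical.

```python
import math

S = 1 / math.sqrt(2)

def dwt1(col):
    """Single-level Haar analysis of an even-length sequence."""
    ev, od = col[0::2], col[1::2]
    approx = [(a + b) * S for a, b in zip(ev, od)]
    detail = [(a - b) * S for a, b in zip(ev, od)]
    return approx, detail

def idwt1(approx, detail):
    """Single-level Haar synthesis, the inverse of dwt1."""
    out = []
    for a, d in zip(approx, detail):
        out.extend([(a + d) * S, (a - d) * S])
    return out

def wavelet_mix(X, approx_scale, detail_scale):
    """Apply DWT -> elementwise scaling -> IDWT to each channel of an n x d input."""
    n, d = len(X), len(X[0])
    Y = [[0.0] * d for _ in range(n)]
    for i in range(d):
        col = [X[t][i] for t in range(n)]                        # X[:, i]
        cA, cD = dwt1(col)                                       # c = DWT(X[:, i])
        cA = [w * c for w, c in zip(approx_scale, cA)]           # f_theta on approximation
        cD = [w * c for w, c in zip(detail_scale, cD)]           # f_theta on detail
        for t, v in enumerate(idwt1(cA, cD)):                    # Y[:, i] = IDWT(c~)
            Y[t][i] = v
    return Y

X = [[1.0, 10.0], [3.0, 20.0], [5.0, 30.0], [9.0, 50.0]]
identity = wavelet_mix(X, [1.0, 1.0], [1.0, 1.0])   # unmodified coefficients
smooth = wavelet_mix(X, [1.0, 1.0], [0.0, 0.0])     # detail coefficients zeroed
```

With unit scales the output reproduces the input, which is the perfect-reconstruction case. Zeroing the detail coefficients replaces each adjacent pair with its average.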
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| vocab_size | int \| None | Size of the vocabulary for token embeddings. If None, expects pre-embedded inputs. | None |
| hidden_dim | int | Hidden dimension size. | 768 |
| num_layers | int | Number of wavelet transformer blocks. | 12 |
| max_sequence_length | int | Maximum sequence length the model can process. | 512 |
| wavelet | WaveletType | Type of wavelet to use (e.g., 'db4', 'sym6', 'coif3'). | 'db4' |
| levels | int | Number of wavelet decomposition levels. | 3 |
| mixing_mode | str | How to mix wavelet coefficients: 'pointwise', 'channel', or 'level'. | 'pointwise' |
| num_classes | int \| None | Number of output classes for classification. | None |
| use_positional_encoding | bool | Whether to use positional encoding. | True |
| positional_encoding_type | PositionalEncodingType | Type of positional encoding. | 'sinusoidal' |
| dropout | float | Dropout probability. | 0.1 |
| ffn_hidden_dim | int \| None | Hidden dimension for FFN. If None, defaults to 4 * hidden_dim. | None |
| norm_eps | float | Epsilon for layer normalization. | 1e-12 |
| output_type | OutputHeadType | Type of output head. | 'classification' |
| gradient_checkpointing | bool | Whether to use gradient checkpointing for memory efficiency. | False |
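When choosing `levels` for a given `wavelet` and sequence length, a common rule of thumb is the bound used by PyWavelets' `dwt_max_level`, \(\lfloor \log_2(n / (F - 1)) \rfloor\) for filter length \(F\). The helper below is a pure-Python sketch of that rule. Whether this library enforces the same bound is not stated here, and the name `max_dwt_level` is hypothetical.

```python
import math

def max_dwt_level(seq_len, filter_len):
    """Largest useful decomposition depth for a signal and filter of the given lengths."""
    if filter_len < 2 or seq_len < filter_len - 1:
        return 0
    return int(math.floor(math.log2(seq_len / (filter_len - 1))))

# Filter lengths: 'db4' has 8 taps, 'db8' has 16, 'coif3' has 18.
db4_at_512 = max_dwt_level(512, 8)
db8_at_512 = max_dwt_level(512, 16)
coif3_at_100 = max_dwt_level(100, 18)
```

Under this rule, longer filters or shorter sequences leave room for fewer decomposition levels.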
Attributes:
| Name | Type | Description |
|---|---|---|
| wavelet | WaveletType | The wavelet family being used. |
| levels | int | Number of decomposition levels. |
| mixing_mode | str | Coefficient mixing strategy. |
| blocks | ModuleList | List of wavelet transformer blocks. |
Methods:
| Name | Description |
|---|---|
| build_blocks | Build wavelet transformer blocks. |
| from_config | Create wavelet transformer from configuration. |
Source code in spectrans/models/wavenet_transformer.py
Functions
build_blocks
Build wavelet transformer blocks.
Returns:
| Type | Description |
|---|---|
| ModuleList | List of wavelet transformer blocks with DWT mixing layers. |
from_config (classmethod)
from_config(config: WaveletTransformerConfig) -> WaveletTransformer
Create wavelet transformer from configuration.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | WaveletTransformerConfig | Configuration object with model parameters. | required |
Returns:
| Type | Description |
|---|---|
| WaveletTransformer | Configured wavelet transformer model. |
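The general pattern behind a `from_config` classmethod can be sketched with stand-in classes. `ToyWaveletConfig` and `ToyWaveletModel` below are hypothetical. The real `WaveletTransformerConfig` fields are not listed in this section, so the field names here simply mirror a few constructor parameters.

```python
from dataclasses import dataclass, asdict

@dataclass
class ToyWaveletConfig:
    """Hypothetical stand-in for WaveletTransformerConfig; field names are assumed."""
    hidden_dim: int = 768
    num_layers: int = 12
    wavelet: str = "db4"
    levels: int = 3

class ToyWaveletModel:
    """Hypothetical stand-in for WaveletTransformer; no real layers are built."""

    def __init__(self, hidden_dim=768, num_layers=12, wavelet="db4", levels=3):
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.wavelet = wavelet
        self.levels = levels

    @classmethod
    def from_config(cls, config):
        # Forward every config field to the constructor as a keyword argument.
        return cls(**asdict(config))

config = ToyWaveletConfig(hidden_dim=512, wavelet="sym6")
model = ToyWaveletModel.from_config(config)
```

Keeping the constructor arguments and config fields in one-to-one correspondence lets a model be rebuilt from a serialized configuration.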
WaveletEncoder
WaveletEncoder(hidden_dim: int = 768, num_layers: int = 12, max_sequence_length: int = 512, wavelet: WaveletType = 'db4', levels: int = 3, mixing_mode: str = 'pointwise', use_positional_encoding: bool = True, positional_encoding_type: PositionalEncodingType = 'sinusoidal', dropout: float = 0.1, ffn_hidden_dim: int | None = None, norm_eps: float = 1e-12, gradient_checkpointing: bool = False)
Bases: WaveletTransformer
Encoder-only wavelet transformer for representation learning.
This variant is designed for extracting representations from sequences using wavelet-based mixing, without any task-specific output head.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| hidden_dim | int | Hidden dimension size. | 768 |
| num_layers | int | Number of wavelet transformer blocks. | 12 |
| max_sequence_length | int | Maximum sequence length. | 512 |
| wavelet | WaveletType | Type of wavelet to use. | 'db4' |
| levels | int | Number of decomposition levels. | 3 |
| mixing_mode | str | Coefficient mixing strategy. | 'pointwise' |
| use_positional_encoding | bool | Whether to use positional encoding. | True |
| positional_encoding_type | PositionalEncodingType | Type of positional encoding. | 'sinusoidal' |
| dropout | float | Dropout probability. | 0.1 |
| ffn_hidden_dim | int \| None | Hidden dimension for FFN. | None |
| norm_eps | float | Layer normalization epsilon. | 1e-12 |
| gradient_checkpointing | bool | Whether to use gradient checkpointing. | False |
WaveletDecoder
WaveletDecoder(vocab_size: int, hidden_dim: int = 768, num_layers: int = 12, max_sequence_length: int = 512, wavelet: WaveletType = 'db4', levels: int = 2, mixing_mode: str = 'pointwise', use_positional_encoding: bool = True, positional_encoding_type: PositionalEncodingType = 'sinusoidal', dropout: float = 0.1, ffn_hidden_dim: int | None = None, norm_eps: float = 1e-12, gradient_checkpointing: bool = False)
Bases: WaveletTransformer
Decoder wavelet transformer for sequence generation.
This variant uses causal wavelet processing suitable for autoregressive generation tasks. The wavelet decomposition is modified to respect causality constraints.
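How the decomposition is made causal is not described here. One possible approach, shown purely as an assumption, is to filter without downsampling and pad only on the left, so that the coefficients at position \(t\) depend on inputs at positions \(\le t\). The function `causal_haar` below is a hypothetical illustration using the Haar filter pair.

```python
import math

S = 1 / math.sqrt(2)

def causal_haar(x):
    """Undecimated one-level Haar analysis with left zero-padding.

    Output t combines x[t] and x[t-1] only, so no future sample is read.
    """
    padded = [0.0] + list(x)
    approx = [(padded[t + 1] + padded[t]) * S for t in range(len(x))]
    detail = [(padded[t + 1] - padded[t]) * S for t in range(len(x))]
    return approx, detail

x = [1.0, 2.0, 3.0, 4.0, 5.0]
x_future_changed = x[:3] + [100.0, -100.0]   # same prefix, different future

a1, d1 = causal_haar(x)
a2, d2 = causal_haar(x_future_changed)

# Coefficients for the shared prefix are identical: nothing leaks from the future.
prefix_unchanged = a1[:3] == a2[:3] and d1[:3] == d2[:3]

# Each sample can still be recovered from its own pair of coefficients.
x_rec = [(a + d) * S for a, d in zip(a1, d1)]
```

Changing samples after position 2 leaves the first three coefficient pairs untouched, which is the property autoregressive generation needs.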
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| vocab_size | int | Size of the vocabulary for token generation. | required |
| hidden_dim | int | Hidden dimension size. | 768 |
| num_layers | int | Number of wavelet transformer blocks. | 12 |
| max_sequence_length | int | Maximum sequence length. | 512 |
| wavelet | WaveletType | Type of wavelet to use. | 'db4' |
| levels | int | Number of decomposition levels (typically lower for causality). | 2 |
| mixing_mode | str | Coefficient mixing strategy. | 'pointwise' |
| use_positional_encoding | bool | Whether to use positional encoding. | True |
| positional_encoding_type | PositionalEncodingType | Type of positional encoding. | 'sinusoidal' |
| dropout | float | Dropout probability. | 0.1 |
| ffn_hidden_dim | int \| None | Hidden dimension for FFN. | None |
| norm_eps | float | Layer normalization epsilon. | 1e-12 |
| gradient_checkpointing | bool | Whether to use gradient checkpointing. | False |