Wavelet Mixing
spectrans.layers.mixing.wavelet
Wavelet-based mixing layers for spectral transformer networks.
Implements neural network layers that perform token mixing through discrete wavelet transforms (DWT). The wavelet domain decomposes a signal into approximation and detail coefficients at multiple resolution levels, so that different frequency bands can be processed separately.
Wavelet mixing layers apply learnable transformations to wavelet coefficients before reconstruction, providing an alternative to attention mechanisms with different inductive biases. The multi-scale nature of wavelets suits signals with hierarchical structure.
Classes:
| Name | Description |
|---|---|
| WaveletMixing | 1D wavelet mixing layer using discrete wavelet transform. |
| WaveletMixing2D | 2D wavelet mixing layer for image-like data processing. |
Examples:
Basic 1D wavelet mixing:
>>> import torch
>>> from spectrans.layers.mixing.wavelet import WaveletMixing
>>> mixer = WaveletMixing(hidden_dim=256, wavelet='db4', levels=3)
>>> x = torch.randn(32, 128, 256)
>>> output = mixer(x)
>>> assert output.shape == x.shape
2D wavelet mixing for spatial data:
>>> from spectrans.layers.mixing.wavelet import WaveletMixing2D
>>> mixer_2d = WaveletMixing2D(channels=256, wavelet='db4', levels=2)
>>> x = torch.randn(32, 256, 64, 64)
>>> output = mixer_2d(x)
Notes
Mathematical Foundation:
The discrete wavelet transform decomposes a signal \(\mathbf{x}\) into approximation coefficients \(\mathbf{c}_A\) and detail coefficients \(\{\mathbf{c}_{D_j}\}_{j=1}^J\) at \(J\) levels:
\[
\text{DWT}_J(\mathbf{x}) = \left\{\mathbf{c}_A, \{\mathbf{c}_{D_j}\}_{j=1}^{J}\right\}
\]
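The decomposition uses filter banks with low-pass filter \(\mathbf{h}\) and high-pass filter \(\mathbf{g}\). Starting from \(\mathbf{c}_{A_0} = \mathbf{x}\), each level filters the previous approximation and downsamples by 2:
\[
\mathbf{c}_{A_j}[k] = \sum_n \mathbf{h}[n - 2k]\, \mathbf{c}_{A_{j-1}}[n], \qquad \mathbf{c}_{D_j}[k] = \sum_n \mathbf{g}[n - 2k]\, \mathbf{c}_{A_{j-1}}[n]
\]
with \(\mathbf{c}_A = \mathbf{c}_{A_J}\).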
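The filter-bank recursion can be sketched in a few lines of NumPy. This is not the spectrans implementation: it assumes the 2-tap Haar filters (db1), for which each output coefficient depends only on the input pair \((c[2k], c[2k+1])\), and it skips boundary handling by requiring the signal length to be divisible by \(2^J\).

```python
import numpy as np

# Haar analysis filters (assumed for illustration; 'db4' uses 8 taps).
H = np.array([1.0, 1.0]) / np.sqrt(2.0)   # low-pass h
G = np.array([1.0, -1.0]) / np.sqrt(2.0)  # high-pass g


def haar_dwt(x: np.ndarray, levels: int):
    """Return (c_A_J, [c_D_1, ..., c_D_J]) for a 1D signal.

    Assumes len(x) is divisible by 2**levels, so no boundary padding is needed.
    """
    approx = np.asarray(x, dtype=float)
    details = []
    for _ in range(levels):
        pairs = approx.reshape(-1, 2)  # rows are (c[2k], c[2k+1])
        # sum_n h[n - 2k] c[n] only touches n = 2k and n = 2k + 1 for 2 taps
        details.append(pairs @ G)
        approx = pairs @ H
    return approx, details


def haar_idwt(approx: np.ndarray, details):
    """Invert haar_dwt level by level, starting from the coarsest detail."""
    for detail in reversed(details):
        out = np.empty(2 * approx.size)
        out[0::2] = (approx + detail) / np.sqrt(2.0)
        out[1::2] = (approx - detail) / np.sqrt(2.0)
        approx = out
    return approx


x = np.arange(8.0)
approx, details = haar_dwt(x, levels=3)
print(approx.shape, [d.shape for d in details])  # (1,) [(4,), (2,), (1,)]
print(np.allclose(haar_idwt(approx, details), x))  # True
```

Because the Haar filters are orthogonal, the sketch also preserves signal energy, which is one way to check a hand-written filter bank.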
Wavelet mixing applies learnable transformations to these coefficients through pointwise mixing with element-wise scaling, channel mixing with linear transformations across feature dimensions, and level mixing with cross-scale interactions using attention mechanisms.
Time complexity is \(O(nd)\) for \(n\)-length signals with \(d\) channels. Space complexity is \(O(nd)\) for coefficient storage. Decomposition typically uses 1-5 levels depending on signal length.
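The linear cost follows from a geometric series: each level halves the signal, so the work summed over all levels stays below that of a single full-length pass through both filters. The count below assumes periodized boundaries (other padding modes add a few coefficients per level) and \(F\) multiply-adds per output coefficient.

```python
def dwt_multiply_adds(n: int, filter_len: int, levels: int) -> int:
    """Count multiply-adds for a J-level 1D DWT of one channel.

    Assumes periodized boundaries, so level j outputs ceil(n / 2**j)
    coefficients per filter, each costing filter_len multiply-adds.
    """
    total = 0
    length = n
    for _ in range(levels):
        length = -(-length // 2)          # ceil division: output length at this level
        total += 2 * filter_len * length  # low-pass plus high-pass outputs
    return total


# For n a multiple of 2**levels the count stays below 2 * F * n for every J.
for levels in (1, 3, 5):
    print(levels, dwt_multiply_adds(1024, 8, levels), 2 * 8 * 1024)
```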
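Daubechies wavelets are orthogonal with compact support and good time localization. Symlets are nearly symmetric variants of Daubechies wavelets with reduced phase distortion. Coiflets add vanishing moments to the scaling function as well as to the wavelet. Biorthogonal wavelets allow symmetric, linear-phase filters while keeping perfect reconstruction.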
All wavelet operations maintain gradient flow for end-to-end training. The transforms use PyTorch-native implementations compatible with automatic differentiation, avoiding external library dependencies that could break gradient computation.
References
Ingrid Daubechies. 1992. Ten Lectures on Wavelets. SIAM, Philadelphia.
Stéphane Mallat. 2009. A Wavelet Tour of Signal Processing: The Sparse Way, 3rd edition. Academic Press, Boston.
See Also
- spectrans.transforms.wavelet : Underlying DWT implementations
- spectrans.layers.mixing.base : Base mixing layer interfaces
- spectrans.layers.mixing.fourier : Fourier-based mixing alternatives
Classes
WaveletMixing
WaveletMixing(hidden_dim: int, wavelet: WaveletType = 'db4', levels: int = 3, mixing_mode: str = 'pointwise', dropout: float = 0.0)
Bases: Module
Token mixing layer using discrete wavelet transform.
Performs mixing in the wavelet domain for multi-resolution processing. Decomposes the input using the DWT, applies learnable mixing to the coefficients, and reconstructs the output, which is added back to the input through a residual connection.
Mathematical Formulation
Given input tensor \(\mathbf{X} \in \mathbb{R}^{B \times N \times D}\) where \(B\) is batch size, \(N\) is sequence length, and \(D\) is hidden dimension:
Step 1: Channel-wise Decomposition
For each channel \(d \in \{0, 1, \ldots, D-1\}\), extract the channel signal:
\[
\mathbf{x}^{(d)} = \mathbf{X}[:, :, d] \in \mathbb{R}^{B \times N}
\]
Apply \(J\)-level DWT decomposition:
\[
\text{DWT}_J(\mathbf{x}^{(d)}) = \left\{\mathbf{c}_{A_J}^{(d)}, \{\mathbf{c}_{D_j}^{(d)}\}_{j=1}^{J}\right\}
\]
Where:
- \(\mathbf{c}_{A_J}^{(d)} \in \mathbb{R}^{B \times L_{A_J}}\) are approximation coefficients at level \(J\)
- \(\mathbf{c}_{D_j}^{(d)} \in \mathbb{R}^{B \times L_{D_j}}\) are detail coefficients at level \(j\)
- \(L_{A_J}\) and \(L_{D_j}\) are coefficient lengths after subsampling
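The coefficient lengths \(L_{A_J}\) and \(L_{D_j}\) depend on the boundary mode, and with padded modes the bands together hold more than \(N\) values. The helper below is a sketch that assumes the rule PyWavelets uses for its padded (non-periodized) extension modes, \(\lfloor (L + F - 1)/2 \rfloor\) per level for \(F\) filter taps; whether spectrans follows the same rule is an assumption.

```python
def coeff_lengths(n: int, filter_len: int, levels: int) -> dict:
    """Per-band coefficient lengths for a padded (non-periodized) 1D DWT.

    Uses floor((L + F - 1) / 2) at each level, as in PyWavelets' padded
    signal extension modes (assumed here, not taken from spectrans).
    """
    lengths = {}
    current = n
    for j in range(1, levels + 1):
        current = (current + filter_len - 1) // 2
        lengths[f"D{j}"] = current   # L_{D_j}
    lengths[f"A{levels}"] = current  # L_{A_J} equals L_{D_J}
    return lengths


# 'db4' has 8 filter taps in PyWavelets.
print(coeff_lengths(128, 8, 3))  # {'D1': 67, 'D2': 37, 'D3': 22, 'A3': 22}
```

Under this rule the four bands hold 148 coefficients for a length-128 input, which is why per-band weights must be sliced to each band's length and the reconstruction trimmed back to \(N\).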
Step 2: Learnable Mixing
Apply mixing transformations based on mode:
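Pointwise Mixing (mixing_mode='pointwise'):
\[
\tilde{\mathbf{c}}_{A_J} = \mathbf{c}_{A_J} \odot \mathbf{W}_{A}[:, :L_{A_J}, :], \qquad \tilde{\mathbf{c}}_{D_j} = \mathbf{c}_{D_j} \odot \mathbf{W}_{D_j}[:, :L_{D_j}, :]
\]
Where \(\mathbf{W}_{A}, \mathbf{W}_{D_j} \in \mathbb{R}^{1 \times \max(L) \times D}\) are learnable parameters sliced to each coefficient length, the coefficients of all \(D\) channels are stacked along the last axis, and \(\odot\) denotes element-wise multiplication with broadcasting.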
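A NumPy sketch of pointwise mixing with dynamic weight slicing. It assumes each band is laid out as a \((B, L, D)\) array and that one \((1, L_{\max}, D)\) weight exists per band; the band names and sizes in the usage lines are hypothetical.

```python
import numpy as np


def pointwise_mix(coeffs: np.ndarray, weight: np.ndarray) -> np.ndarray:
    """Scale one coefficient band element-wise by a weight sliced to its length.

    coeffs: (B, L, D), the band's coefficients with channels on the last axis.
    weight: (1, L_max, D) with L_max >= L.
    """
    length = coeffs.shape[1]
    if length > weight.shape[1]:
        raise ValueError("weight is shorter than the coefficient band")
    return coeffs * weight[:, :length, :]  # broadcasts over the batch axis


# Bands of different lengths share the same (1, L_max, D) weight layout.
rng = np.random.default_rng(0)
bands = {"A3": rng.standard_normal((2, 22, 4)), "D1": rng.standard_normal((2, 67, 4))}
weights = {name: np.ones((1, 67, 4)) for name in bands}
mixed = {name: pointwise_mix(c, weights[name]) for name, c in bands.items()}
```

Sizing every weight to the longest band keeps the parameter shapes fixed, at the cost of unused trailing entries for shorter bands.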
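Channel Mixing (mixing_mode='channel'):
\[
\tilde{\mathbf{c}}_{A_J} = \mathbf{c}_{A_J} \mathbf{W}_{A}, \qquad \tilde{\mathbf{c}}_{D_j} = \mathbf{c}_{D_j} \mathbf{W}_{D_j}
\]
Where \(\mathbf{W}_{A}, \mathbf{W}_{D_j} \in \mathbb{R}^{1 \times D \times D}\) are initialized as identity matrices and act on the hidden dimension of the stacked coefficients.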
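With the \((1, D, D)\) weight shape given above, channel mixing is a batched matrix product over the hidden dimension. The sketch assumes \((B, L, D)\) coefficient bands and shows why identity initialization leaves the coefficients unchanged.

```python
import numpy as np


def channel_mix(coeffs: np.ndarray, weight: np.ndarray) -> np.ndarray:
    """Apply a D x D linear map across the hidden dimension of one band.

    coeffs: (B, L, D); weight: (1, D, D). Matmul broadcasts the leading 1.
    """
    return coeffs @ weight  # (B, L, D) @ (1, D, D) -> (B, L, D)


D = 4
identity_init = np.eye(D)[None]  # shape (1, D, D)
c = np.random.default_rng(1).standard_normal((2, 22, D))
print(np.allclose(channel_mix(c, identity_init), c))  # True
```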
Level Mixing (mixing_mode='level'):
Cross-level attention is applied to the coefficients of all levels jointly.
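The internals of the level-mixing attention are not described here, so the following is only a generic illustration: single-head scaled dot-product self-attention over the concatenation of all bands, after which the result is split back into per-level bands. The projection matrices `wq`, `wk`, `wv` are hypothetical.

```python
import numpy as np


def cross_level_attention(bands, wq, wk, wv):
    """Single-head self-attention over the coefficients of all levels at once.

    bands: list of (B, L_j, D) arrays; wq, wk, wv: (D, D) projections.
    Returns mixed bands with their original lengths.
    """
    lengths = [b.shape[1] for b in bands]
    tokens = np.concatenate(bands, axis=1)  # (B, sum of L_j, D)
    q, k, v = tokens @ wq, tokens @ wk, tokens @ wv
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(q.shape[-1])
    scores -= scores.max(axis=-1, keepdims=True)  # numerical stability
    attn = np.exp(scores)
    attn /= attn.sum(axis=-1, keepdims=True)  # softmax over every level's coefficients
    mixed = attn @ v
    splits = np.cumsum(lengths)[:-1]
    return np.split(mixed, splits, axis=1)  # back to per-level bands
```

Because every coefficient attends to every other, this variant costs time quadratic in the total coefficient count, unlike the pointwise and channel modes.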
Step 3: Reconstruction
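Reconstruct the signal using inverse DWT:
\[
\hat{\mathbf{x}}^{(d)} = \text{IDWT}_J\left(\tilde{\mathbf{c}}_{A_J}^{(d)}, \{\tilde{\mathbf{c}}_{D_j}^{(d)}\}_{j=1}^{J}\right)
\]
If the reconstructed length differs from \(N\), truncate or pad \(\hat{\mathbf{x}}^{(d)}\) back to length \(N\).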
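Length adjustment amounts to a truncate-or-pad step on the sequence axis. The sketch assumes zero padding; the padding rule the layer actually uses is not stated.

```python
import numpy as np


def match_length(x: np.ndarray, n: int) -> np.ndarray:
    """Truncate or zero-pad the last axis of x to length n.

    Zero padding is an assumption; the layer's actual padding rule may differ.
    """
    length = x.shape[-1]
    if length >= n:
        return x[..., :n]
    pad = [(0, 0)] * (x.ndim - 1) + [(0, n - length)]
    return np.pad(x, pad)  # default mode 'constant' pads with zeros


print(match_length(np.ones((2, 130)), 128).shape)  # (2, 128)
```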
Step 4: Residual Connection and Dropout
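Combine all channels and apply the residual connection:
\[
\hat{\mathbf{X}} = [\hat{\mathbf{x}}^{(0)}, \ldots, \hat{\mathbf{x}}^{(D-1)}] \in \mathbb{R}^{B \times N \times D}, \qquad \mathbf{Y} = \mathbf{X} + \text{Dropout}(\hat{\mathbf{X}})
\]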
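The residual step with dropout can be sketched with inverted dropout, which is how `torch.nn.Dropout` behaves (survivors scaled by \(1/(1-p)\) in training, a no-op in evaluation). The NumPy function below is an illustration, not the layer's code.

```python
import numpy as np


def residual_dropout(x, x_hat, p, training, rng):
    """Return x + Dropout(x_hat) using inverted dropout.

    During training each element of x_hat is zeroed with probability p and
    the survivors are scaled by 1 / (1 - p); otherwise dropout is a no-op.
    """
    if not training or p == 0.0:
        return x + x_hat
    keep = rng.random(x_hat.shape) >= p
    return x + x_hat * keep / (1.0 - p)


rng = np.random.default_rng(0)
y = residual_dropout(np.zeros((2, 3)), np.ones((2, 3)), p=0.5, training=True, rng=rng)
print(sorted(set(y.ravel().tolist())))  # some subset of [0.0, 2.0]
```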
Complexity Analysis
- Time Complexity: \(O(DN)\) for the transforms plus the cost of the mixing mode
    - \(O(N)\) per channel for the full \(J\)-level DWT/IDWT, since the coefficient length halves at each level (\(O(NJ)\) is a looser bound)
    - Pointwise mixing costs \(O(DN)\) and channel mixing \(O(ND^2)\)
    - For pointwise mixing both terms are \(O(DN)\); for channel mixing the \(O(ND^2)\) term dominates when \(D\) is large
- Space Complexity: \(O(DN + P)\) where \(P\) is parameter count
    - \(O(DN)\) for storing coefficient tensors
    - Parameter count depends on mixing mode:
        - Pointwise: \(P = O(JLD)\) where \(L\) is max coefficient length (one weight per band)
        - Channel: \(P = O(JD^2)\)
        - Level: \(P = O(D^2)\) for attention parameters
Implementation Notes
- Uses a PyTorch-native DWT implementation for gradient compatibility
- Slices weights dynamically to align with variable-length coefficients
- Maintains the perfect reconstruction property through careful length handling
- Processes each channel independently for computational efficiency
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| hidden_dim | int | Hidden dimension size \(D\). | required |
| wavelet | str | Wavelet type (e.g., 'db1', 'db4', 'sym2'). Determines filter bank characteristics. | 'db4' |
| levels | int | Number of decomposition levels \(J\). Controls the resolution hierarchy. | 3 |
| mixing_mode | str | Mixing strategy: 'pointwise' (element-wise scaling), 'channel' (linear map across the hidden dimension), 'level' (cross-level attention). | 'pointwise' |
| dropout | float | Dropout probability applied to the reconstructed signal before the residual connection. | 0.0 |
Attributes:
| Name | Type | Description |
|---|---|---|
| dwt | DWT1D | Wavelet transform module implementing PyTorch-native DWT/IDWT. |
| mixing_weights | ParameterDict | Learnable parameters for coefficient mixing; their structure depends on the mixing mode. |
| dropout | Dropout | Dropout layer for regularization. |
Raises:
| Type | Description |
|---|---|
| ValueError | If :attr: |
Examples:
Basic usage with pointwise mixing:
>>> mixer = WaveletMixing(hidden_dim=256, wavelet='db4', levels=3)
>>> x = torch.randn(32, 128, 256) # (batch, seq_len, hidden)
>>> output = mixer(x)
>>> assert output.shape == x.shape
Channel mixing with identity initialization:
>>> mixer = WaveletMixing(hidden_dim=64, mixing_mode='channel', levels=2)
>>> x = torch.randn(16, 64, 64)
>>> output = mixer(x)
>>> # Identity-initialized weights: the wavelet branch starts as a near-reconstruction of x
Cross-level mixing with attention:
>>> mixer = WaveletMixing(hidden_dim=128, mixing_mode='level', levels=4)
>>> x = torch.randn(8, 256, 128)
>>> output = mixer(x) # Attention applied across wavelet levels
Methods:
| Name | Description |
|---|---|
| forward | Apply wavelet-based mixing following the mathematical formulation. |
| from_config | Create WaveletMixing from configuration. |
Functions
forward
Apply wavelet-based mixing following the mathematical formulation.
Implements the complete wavelet mixing pipeline: decomposition → mixing → reconstruction → residual. Each hidden dimension is processed independently to maintain channel separability.
Mathematical Implementation
The forward pass implements the mathematical formulation exactly:
- Channel Extraction: \(\mathbf{x}^{(d)} = \mathbf{X}[:, :, d]\) for \(d = 0, \ldots, D-1\)
- Wavelet Decomposition: \(\text{DWT}_J(\mathbf{x}^{(d)}) \rightarrow \{\mathbf{c}_{A_J}^{(d)}, \{\mathbf{c}_{D_j}^{(d)}\}\}\)
- Learnable Mixing: Apply mode-specific transformations to coefficients
- Signal Reconstruction: \(\text{IDWT}_J(\text{mixed coefficients}) \rightarrow \hat{\mathbf{x}}^{(d)}\)
- Channel Concatenation: \(\hat{\mathbf{X}} = [\hat{\mathbf{x}}^{(0)}, \ldots, \hat{\mathbf{x}}^{(D-1)}]\)
- Residual Connection: \(\mathbf{Y} = \mathbf{X} + \text{Dropout}(\hat{\mathbf{X}})\)
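The pipeline above can be traced end to end in a compact NumPy sketch. It is not the spectrans code: it assumes the Haar wavelet (so a sequence length divisible by \(2^J\) reconstructs exactly), replaces the pointwise weights with a hypothetical per-band, per-channel scale, and omits dropout.

```python
import numpy as np

SQRT2 = np.sqrt(2.0)


def haar_dwt(x, levels):
    """Haar DWT along the last axis; its length must be divisible by 2**levels."""
    details = []
    for _ in range(levels):
        even, odd = x[..., 0::2], x[..., 1::2]
        details.append((even - odd) / SQRT2)
        x = (even + odd) / SQRT2
    return x, details


def haar_idwt(approx, details):
    """Inverse of haar_dwt, coarsest detail first."""
    for d in reversed(details):
        out = np.empty(approx.shape[:-1] + (2 * approx.shape[-1],))
        out[..., 0::2] = (approx + d) / SQRT2
        out[..., 1::2] = (approx - d) / SQRT2
        approx = out
    return approx


def wavelet_mixing_forward(X, band_scales, levels):
    """Sketch of the forward pass for X of shape (B, N, D), without dropout.

    band_scales: list of (D,) arrays ordered [A_J, D_1, ..., D_J], a
    hypothetical simplification of the pointwise weights.
    """
    _, N, D = X.shape
    channels = []
    for d in range(D):
        x_d = X[:, :, d]                                # channel extraction, (B, N)
        c_a, c_ds = haar_dwt(x_d, levels)               # wavelet decomposition
        c_a = c_a * band_scales[0][d]                   # learnable mixing (scaling)
        c_ds = [c * s[d] for c, s in zip(c_ds, band_scales[1:])]
        channels.append(haar_idwt(c_a, c_ds)[:, :N])    # reconstruction, trimmed to N
    X_hat = np.stack(channels, axis=-1)                 # channel concatenation
    return X + X_hat                                    # residual connection
```

With every scale set to 1 the branch reconstructs its input exactly, so the output is \(2\mathbf{X}\); with every scale set to 0 the layer reduces to the identity.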
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Input tensor of shape \((B, N, D)\) where \(B\) is batch size, \(N\) is sequence length, and \(D\) is hidden dimension. | required |
Returns:
| Type | Description |
|---|---|
| Tensor | Mixed output tensor of identical shape \((B, N, D)\) with wavelet-domain mixing and the residual connection applied. |
Notes
- Dynamic coefficient length handling ensures robustness to varying sequence lengths
- Perfect reconstruction property maintained through careful padding/truncation
- Gradient flow preserved through PyTorch-native operations
from_config (classmethod)
from_config(config: WaveletMixingConfig) -> WaveletMixing
Create WaveletMixing from configuration.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | WaveletMixingConfig | Typed and validated configuration. | required |

Returns:
| Type | Description |
|---|---|
| WaveletMixing | Configured instance. |
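The from_config pattern builds the layer from a typed configuration object. The fields of WaveletMixingConfig are not listed here, so the sketch below uses hypothetical stand-ins (`WaveletMixingConfigSketch`, `MixerSketch`) whose fields simply mirror the constructor arguments; the validation rule is also an assumption.

```python
from dataclasses import asdict, dataclass


@dataclass(frozen=True)
class WaveletMixingConfigSketch:
    """Hypothetical config; fields mirror the WaveletMixing constructor arguments."""

    hidden_dim: int
    wavelet: str = "db4"
    levels: int = 3
    mixing_mode: str = "pointwise"
    dropout: float = 0.0

    def __post_init__(self) -> None:
        # Validation rule assumed for illustration.
        if self.mixing_mode not in ("pointwise", "channel", "level"):
            raise ValueError(f"unknown mixing_mode: {self.mixing_mode!r}")


class MixerSketch:
    """Hypothetical stand-in for the layer, showing only construction."""

    def __init__(self, hidden_dim: int, wavelet: str = "db4", levels: int = 3,
                 mixing_mode: str = "pointwise", dropout: float = 0.0) -> None:
        self.hidden_dim = hidden_dim
        self.wavelet = wavelet
        self.levels = levels
        self.mixing_mode = mixing_mode
        self.dropout = dropout

    @classmethod
    def from_config(cls, config: WaveletMixingConfigSketch) -> "MixerSketch":
        # Field names match __init__ parameters, so the config unpacks directly.
        return cls(**asdict(config))


mixer = MixerSketch.from_config(WaveletMixingConfigSketch(hidden_dim=64, mixing_mode="channel"))
```

Validating in the config rather than in the layer means an invalid setting fails before any parameters are allocated.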
WaveletMixing2D
WaveletMixing2D(channels: int, wavelet: WaveletType = 'db4', levels: int = 2, mixing_mode: str = 'subband')
Bases: Module
2D wavelet mixing layer for image-like data.
Performs mixing in 2D wavelet domain, suitable for vision transformers and other architectures processing 2D spatial data. Processes spatial information through multi-resolution wavelet subbands.
Mathematical Formulation
Given input tensor \(\mathbf{X} \in \mathbb{R}^{B \times C \times H \times W}\) where \(B\) is batch size, \(C\) is channels, \(H\) is height, and \(W\) is width:
Step 1: Channel-wise 2D Decomposition
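For each channel \(c \in \{0, 1, \ldots, C-1\}\), extract spatial data:
\[
\mathbf{X}^{(c)} = \mathbf{X}[:, c, :, :] \in \mathbb{R}^{B \times H \times W}
\]
Apply \(J\)-level 2D DWT decomposition:
\[
\text{DWT2D}_J(\mathbf{X}^{(c)}) = \left\{\mathbf{LL}_J^{(c)}, \{(\mathbf{LH}_j^{(c)}, \mathbf{HL}_j^{(c)}, \mathbf{HH}_j^{(c)})\}_{j=1}^{J}\right\}
\]
Where:
- \(\mathbf{LL}_J^{(c)} \in \mathbb{R}^{B \times H_J \times W_J}\) is the approximation subband (low-low)
- \(\mathbf{LH}_j^{(c)}, \mathbf{HL}_j^{(c)}, \mathbf{HH}_j^{(c)} \in \mathbb{R}^{B \times H_j \times W_j}\) are detail subbands
- \(H_j = \frac{H}{2^j}\), \(W_j = \frac{W}{2^j}\) are spatial dimensions at level \(j\)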
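One level of a separable 2D DWT filters along the height axis and then along the width axis. The NumPy sketch below assumes Haar filters and even \(H\), \(W\); applying it again to the LL output gives the next level. Libraries differ on which of LH and HL means "high along width", so the naming here is a stated convention, not necessarily the one spectrans uses.

```python
import numpy as np


def haar_dwt2(x):
    """One-level separable 2D Haar DWT on the last two axes (H and W even).

    Returns LL, LH, HL, HH, each of shape (..., H/2, W/2). Here LH means
    low-pass along height and high-pass along width.
    """
    s = np.sqrt(2.0)
    lo = (x[..., 0::2, :] + x[..., 1::2, :]) / s  # low-pass along height
    hi = (x[..., 0::2, :] - x[..., 1::2, :]) / s  # high-pass along height
    ll = (lo[..., 0::2] + lo[..., 1::2]) / s
    lh = (lo[..., 0::2] - lo[..., 1::2]) / s
    hl = (hi[..., 0::2] + hi[..., 1::2]) / s
    hh = (hi[..., 0::2] - hi[..., 1::2]) / s
    return ll, lh, hl, hh


def haar_idwt2(ll, lh, hl, hh):
    """Invert haar_dwt2 by undoing the width step, then the height step."""
    s = np.sqrt(2.0)
    # Interleave even and odd columns.
    lo = np.stack([(ll + lh) / s, (ll - lh) / s], axis=-1).reshape(*ll.shape[:-1], -1)
    hi = np.stack([(hl + hh) / s, (hl - hh) / s], axis=-1).reshape(*hl.shape[:-1], -1)
    # Interleave even and odd rows.
    h, w = lo.shape[-2:]
    return np.stack([(lo + hi) / s, (lo - hi) / s], axis=-2).reshape(*lo.shape[:-2], 2 * h, w)


x = np.random.default_rng(0).standard_normal((2, 8, 6))  # (B, H, W) for one channel
print(np.allclose(haar_idwt2(*haar_dwt2(x)), x))  # True
```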
Step 2: Subband Mixing
Apply mixing transformations based on mode:
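Subband Mixing (mixing_mode='subband'):
Independent processing of each subband using convolutional networks:
\[
\widetilde{\mathbf{LL}}_J = f_{LL}(\mathbf{LL}_J), \qquad \widetilde{\mathbf{LH}}_j = f_{LH}(\mathbf{LH}_j), \qquad \widetilde{\mathbf{HL}}_j = f_{HL}(\mathbf{HL}_j), \qquad \widetilde{\mathbf{HH}}_j = f_{HH}(\mathbf{HH}_j)
\]
Where \(f_{\cdot}\) are learnable convolutional transformations.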
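As an illustration of independent per-subband transforms, the sketch uses a single 3x3 zero-padded convolution per subband (cross-correlation, matching the convention of `torch.nn.Conv2d`). The actual `ll_mixer` and `detail_mixers` networks are not described, so the kernels and subband keys here are hypothetical.

```python
import numpy as np


def conv3x3_same(x, kernel):
    """3x3 cross-correlation with zero padding, one channel.

    x: (B, H, W); kernel: (3, 3). Output has the same shape as x.
    """
    padded = np.pad(x, ((0, 0), (1, 1), (1, 1)))
    out = np.zeros_like(x, dtype=float)
    H, W = x.shape[-2:]
    for i in range(3):
        for j in range(3):
            out += kernel[i, j] * padded[:, i:i + H, j:j + W]
    return out


def mix_subbands(subbands, kernels):
    """Apply an independent transform f_s (a hypothetical 3x3 kernel) to each subband s."""
    return {name: conv3x3_same(band, kernels[name]) for name, band in subbands.items()}
```

Keeping one transform per subband lets the model learn different spatial filters for smooth content and for each edge orientation.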
Cross Mixing (mixing_mode='cross'):
Cross-attention is applied across all subbands.
Step 3: 2D Reconstruction
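Reconstruct the spatial signal:
\[
\tilde{\mathbf{X}}^{(c)} = \text{IDWT2D}_J\left(\widetilde{\mathbf{LL}}_J^{(c)}, \{(\widetilde{\mathbf{LH}}_j^{(c)}, \widetilde{\mathbf{HL}}_j^{(c)}, \widetilde{\mathbf{HH}}_j^{(c)})\}_{j=1}^{J}\right)
\]
Step 4: Channel Concatenation and Residual
\[
\hat{\mathbf{X}} = [\tilde{\mathbf{X}}^{(0)}, \ldots, \tilde{\mathbf{X}}^{(C-1)}] \in \mathbb{R}^{B \times C \times H \times W}, \qquad \mathbf{Y} = \mathbf{X} + \hat{\mathbf{X}}
\]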
Complexity Analysis
- Time Complexity: \(O(CHW \cdot J) + O(\text{mixing operations})\)
- Space Complexity: \(O(CHW + \text{subband storage})\)
Where mixing complexity depends on mode:
- Subband: \(O(\text{conv operations per subband})\)
- Cross: \(O(\text{attention across subbands})\)
- Attention: \(O(\text{transformer encoder})\)
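Each 2D level works on a quarter of the previous area, so \(O(CHW \cdot J)\) is a loose bound and the per-channel work is closer to \(O(HW)\). The helper below lists subband shapes, assuming a periodized transform and \(H\), \(W\) divisible by \(2^J\); under that assumption the subbands hold exactly \(H \cdot W\) coefficients (padded modes add a few more).

```python
def subband_shapes(h: int, w: int, levels: int) -> dict:
    """Subband shapes of a periodized 2D DWT; h and w divisible by 2**levels."""
    shapes = {}
    for j in range(1, levels + 1):
        h, w = h // 2, w // 2
        for name in ("LH", "HL", "HH"):
            shapes[f"{name}{j}"] = (h, w)
    shapes[f"LL{levels}"] = (h, w)
    return shapes


shapes = subband_shapes(64, 64, levels=2)
total = sum(rows * cols for rows, cols in shapes.values())
print(total)  # 4096, the same as 64 * 64
```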
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| channels | int | Number of input/output channels \(C\). | required |
| wavelet | str | Wavelet type determining 2D filter bank characteristics. | 'db4' |
| levels | int | Number of decomposition levels \(J\). | 2 |
| mixing_mode | str | Subband mixing strategy: 'subband' (independent), 'cross' (attention), 'attention' (transformer). | 'subband' |
Attributes:
| Name | Type | Description |
|---|---|---|
| dwt | DWT2D | 2D wavelet transform module. |
| ll_mixer | Sequential | Convolutional network for LL subband (subband mode). |
| detail_mixers | ModuleList | Convolutional networks for detail subbands (subband mode). |
| cross_mixer | MultiheadAttention | Cross-attention module (cross mode). |
| subband_attention | TransformerEncoder | Transformer encoder for subband attention (attention mode). |
Raises:
| Type | Description |
|---|---|
| ValueError | If :attr: |
Examples:
Independent subband processing:
>>> mixer = WaveletMixing2D(channels=256, wavelet='db4', levels=2)
>>> x = torch.randn(32, 256, 64, 64) # (batch, channels, height, width)
>>> output = mixer(x)
>>> assert output.shape == x.shape
Cross-subband attention:
>>> mixer = WaveletMixing2D(channels=128, mixing_mode='cross', levels=3)
>>> x = torch.randn(16, 128, 128, 128)
>>> output = mixer(x) # Attention applied across all wavelet subbands
Methods:
| Name | Description |
|---|---|
| forward | Apply 2D wavelet-based mixing following the mathematical formulation. |
| from_config | Create WaveletMixing2D from configuration. |
Functions
forward
Apply 2D wavelet-based mixing following the mathematical formulation.
Implements complete 2D wavelet mixing: spatial decomposition → subband mixing → reconstruction → residual connection. Each channel is processed independently.
Mathematical Implementation
- Channel Extraction: \(\mathbf{X}^{(c)} = \mathbf{X}[:, c, :, :]\) for each channel \(c\)
- 2D Wavelet Decomposition: \(\text{DWT2D}_J(\mathbf{X}^{(c)}) \rightarrow \text{subbands}\)
- Subband Mixing: Apply mode-specific transformations to wavelet subbands
- 2D Reconstruction: \(\text{IDWT2D}_J(\text{mixed subbands}) \rightarrow \tilde{\mathbf{X}}^{(c)}\)
- Channel Stacking: \(\hat{\mathbf{X}} = [\tilde{\mathbf{X}}^{(0)}, \ldots, \tilde{\mathbf{X}}^{(C-1)}]\)
- Residual Connection: \(\mathbf{Y} = \mathbf{X} + \hat{\mathbf{X}}\)
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Input tensor of shape \((B, C, H, W)\) where \(B\) is batch size, \(C\) is channels, \(H\) is height, and \(W\) is width. | required |
Returns:
| Type | Description |
|---|---|
| Tensor | Mixed output tensor of identical shape \((B, C, H, W)\) with 2D wavelet-domain mixing and the residual connection applied. |
Notes
- Spatial dimensions preserved through careful reconstruction handling
- Different mixing strategies provide various inductive biases
- Subband mode: Independent processing emphasizes local features
- Cross mode: Attention enables global subband interactions
- Attention mode: Full transformer encoder for complex dependencies
from_config (classmethod)
from_config(config: WaveletMixing2DConfig) -> WaveletMixing2D
Create WaveletMixing2D from configuration.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | WaveletMixing2DConfig | Typed and validated configuration. | required |

Returns:
| Type | Description |
|---|---|
| WaveletMixing2D | Configured instance. |