Wavelet Transforms
spectrans.transforms.wavelet
PyWavelets-compatible Discrete Wavelet Transform implementations.
This module provides DWT implementations that exactly match PyWavelets behavior while maintaining full gradient support for PyTorch.
Classes:
| Name | Description |
|---|---|
| DWT1D | 1D Discrete Wavelet Transform with multi-level support. |
| DWT2D | 2D Discrete Wavelet Transform using separable 1D transforms. |
Functions:
| Name | Description |
|---|---|
| get_wavelet_filters | Extract filter coefficients from PyWavelets. |
Examples:
Basic 1D wavelet transform:
>>> import torch
>>> from spectrans.transforms.wavelet import DWT1D
>>> dwt = DWT1D(wavelet='db4', levels=2)
>>> x = torch.randn(32, 256)
>>> cA, cD_list = dwt.decompose(x)
>>> x_rec = dwt.reconstruct((cA, cD_list))
>>> error = torch.max(torch.abs(x - x_rec))
>>> print(f"Reconstruction error: {error:.2e}") # Should be < 1e-6
2D wavelet transform for images:
>>> from spectrans.transforms.wavelet import DWT2D
>>> dwt2d = DWT2D(wavelet='db2', levels=2)
>>> image = torch.randn(1, 64, 64)
>>> ll, detail_bands = dwt2d.decompose(image)
>>> reconstructed = dwt2d.reconstruct((ll, detail_bands))
Multi-level decomposition with energy analysis:
>>> dwt = DWT1D(wavelet='db4', levels=3)
>>> x = torch.randn(1, 512)
>>> cA, cD_list = dwt.decompose(x)
>>> # Compare coefficient energy with input energy (Parseval's theorem for orthogonal wavelets)
>>> energy_input = torch.sum(x ** 2)
>>> energy_coeffs = torch.sum(cA ** 2) + sum(torch.sum(cD ** 2) for cD in cD_list)
>>> print(f"Energy ratio: {energy_coeffs / energy_input:.6f}") # ≈ 1.0; exact only with periodic boundaries, since symmetric padding adds boundary coefficients
Notes
Mathematical Foundations: The Discrete Wavelet Transform (DWT) decomposes a signal into approximation and detail coefficients through iterative filtering and downsampling.
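For a signal \(\mathbf{x}[n]\) of length \(N\), the single-level DWT produces:
\[c_A[k] = \sum_{n} x[n]\, h[2k - n], \qquad c_D[k] = \sum_{n} x[n]\, g[2k - n]\]
where \(h[n]\) and \(g[n]\) are the low-pass and high-pass analysis filters. The reconstruction is achieved through:
\[x[n] = \sum_{k} c_A[k]\, h'[n - 2k] + \sum_{k} c_D[k]\, g'[n - 2k]\]
where \(h'[n]\) and \(g'[n]\) are the synthesis filters.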
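A minimal NumPy sketch of the single-level analysis \(c_A[k] = \sum_n x[n]\,h[2k-n]\) (and likewise \(c_D\) with \(g\)) and the synthesis \(x[n] = \sum_k c_A[k]\,h'[n-2k] + c_D[k]\,g'[n-2k]\) with \(h'[n] = h[-n]\). It assumes Haar filters and periodic (wrap-around) boundaries rather than the symmetric padding this module uses; `analysis` and `synthesis` are illustrative names, not part of `spectrans`.

```python
import numpy as np

# Haar analysis filters (assumption; any orthonormal pair h, g works the same way)
h = np.array([1.0, 1.0]) / np.sqrt(2.0)   # low-pass
g = np.array([1.0, -1.0]) / np.sqrt(2.0)  # high-pass


def analysis(x, h, g):
    """c[k] = sum_n x[n] f[2k - n] for f = h, g, with periodic wrap-around."""
    N = len(x)
    cA, cD = np.zeros(N // 2), np.zeros(N // 2)
    for k in range(N // 2):
        for m in range(len(h)):      # substitute m = 2k - n, i.e. n = 2k - m
            n = (2 * k - m) % N      # periodic boundary, not symmetric padding
            cA[k] += x[n] * h[m]
            cD[k] += x[n] * g[m]
    return cA, cD


def synthesis(cA, cD, h, g, N):
    """x[n] = sum_k cA[k] h'[n - 2k] + cD[k] g'[n - 2k], with h'[n] = h[-n]."""
    x = np.zeros(N)
    for k in range(len(cA)):
        for m in range(len(h)):      # h'[n - 2k] = h[2k - n] = h[m]
            n = (2 * k - m) % N
            x[n] += cA[k] * h[m] + cD[k] * g[m]
    return x


x = np.random.default_rng(0).standard_normal(16)
cA, cD = analysis(x, h, g)
x_rec = synthesis(cA, cD, h, g, len(x))
print(f"max reconstruction error: {np.max(np.abs(x - x_rec)):.1e}")
```

With periodic boundaries and an orthonormal filter pair, the round trip is exact up to floating-point rounding and the coefficient energy equals the signal energy.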
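Multi-Resolution Analysis: The \(J\)-level DWT recursively applies the transform to the approximation coefficients, creating a dyadic decomposition in which level \(j\) has length approximately \(N/2^j\) (boundary padding adds a few samples per level). The level-\(j\) approximation covers the frequency band \([0, \pi/2^j]\) and the level-\(j\) detail covers \([\pi/2^j, \pi/2^{j-1}]\).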
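The recursion can be sketched with Haar filters alone, assuming an input length divisible by \(2^J\) so that every level halves exactly. Details are collected finest to coarsest, mirroring the ordering documented for `DWT1D.decompose`; `haar_step` and `haar_wavedec` are hypothetical helpers.

```python
import numpy as np


def haar_step(x):
    """One Haar analysis step on an even-length 1D array (no boundary handling needed)."""
    s = np.sqrt(2.0)
    return (x[0::2] + x[1::2]) / s, (x[0::2] - x[1::2]) / s


def haar_wavedec(x, levels):
    """Recursively split the approximation; details listed finest to coarsest."""
    cA, details = x, []
    for _ in range(levels):
        cA, cD = haar_step(cA)   # only the approximation is split again
        details.append(cD)
    return cA, details


x = np.arange(64, dtype=float)
cA, details = haar_wavedec(x, levels=3)
print(len(cA), [len(d) for d in details])  # 8 [32, 16, 8]
```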
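Perfect Reconstruction: For orthogonal wavelets, the synthesis filters are the time-reversed analysis filters: \(h'[n] = h[-n]\) and \(g'[n] = g[-n]\). With periodic boundary handling the transform also preserves energy: \(\|\mathbf{x}\|^2 = \|\mathbf{c}_A\|^2 + \sum_{j} \|\mathbf{c}_{D_j}\|^2\). With symmetric padding the extra boundary coefficients make this equality hold only approximately.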
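Time-reversed synthesis filters invert the analysis because orthogonal filters satisfy the double-shift conditions \(\sum_n h[n]\,h[n-2k] = \delta[k]\), the same for \(g\), and \(\sum_n h[n]\,g[n-2k] = 0\). The sketch below checks them for the 4-tap Daubechies filter written in closed form, with \(g\) built by the alternating flip \(g[n] = (-1)^n h[L-1-n]\). PyWavelets may store the same taps in reversed order; `shifted_inner` is an illustrative helper.

```python
import numpy as np

# Daubechies-2 (4-tap) low-pass filter in closed form
s3 = np.sqrt(3.0)
h = np.array([1 + s3, 3 + s3, 3 - s3, 1 - s3]) / (4 * np.sqrt(2.0))
L = len(h)
# Alternating flip: g[n] = (-1)^n h[L - 1 - n]
g = np.array([(-1) ** n * h[L - 1 - n] for n in range(L)])


def shifted_inner(a, b, k):
    """sum_n a[n] b[n - 2k], treating out-of-range indices as zero."""
    return sum(a[n] * b[n - 2 * k] for n in range(len(a)) if 0 <= n - 2 * k < len(b))


# Shifts beyond +/-1 have no overlap for a 4-tap filter
for k in (-1, 0, 1):
    print(k, shifted_inner(h, h, k), shifted_inner(g, g, k), shifted_inner(h, g, k))
```

Expected pattern: `h`/`h` and `g`/`g` give 1 at shift 0 and 0 elsewhere, and `h`/`g` is 0 at every shift.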
Implementation Details:
- Convolution output is sampled starting at index \((\text{step} - 1) = 1\) with stride 2
- Symmetric mode reflects the signal with the edge sample repeated: [d,c,b,a | a,b,c,d | d,c,b,a]
- Uses `conv1d` with flipped filters (`conv1d` computes cross-correlation, so flipping the filter yields a true convolution)
- IDWT uses `conv_transpose1d` with stride 2 for implicit upsampling
- Output lengths follow PyWavelets formulas
Algorithm Complexity:
- Forward/Inverse DWT: \(O(N)\) for \(N\)-length signal
- Memory: \(O(N)\) for coefficients
Gradient Support: All operations use native PyTorch operations ensuring full autograd support.
Numerical Precision:
- Filters are extracted in `float64` and computation runs in `float32`
- Perfect reconstruction to \(\sim 10^{-7}\) for `float32`
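To see where the \(\sim 10^{-7}\) figure comes from, the sketch below computes the Haar coefficient \(1/\sqrt{2}\) in `float64`, casts it to the working dtype, and measures the round-trip error in each precision. It uses NumPy in place of PyTorch as a stand-in; `haar_roundtrip` is a hypothetical helper.

```python
import numpy as np


def haar_roundtrip(x):
    """One Haar analysis + synthesis step, computed in x's dtype."""
    # Filter coefficient computed in float64, then cast to the working dtype
    c = np.asarray(1.0 / np.sqrt(2.0)).astype(x.dtype)
    cA = (x[0::2] + x[1::2]) * c
    cD = (x[0::2] - x[1::2]) * c
    y = np.empty_like(x)
    y[0::2] = (cA + cD) * c
    y[1::2] = (cA - cD) * c
    return y


x64 = np.random.default_rng(0).standard_normal(1024)
x32 = x64.astype(np.float32)
err64 = np.max(np.abs(haar_roundtrip(x64) - x64))
err32 = np.max(np.abs(haar_roundtrip(x32) - x32))
print(f"float64: {err64:.1e}, float32: {err32:.1e}")
```

The `float32` error is dominated by rounding in the working precision (machine epsilon is about \(1.2 \times 10^{-7}\)), not by the filter values.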
Supported Wavelets:
Daubechies (db1-db38), Symlets (sym2-sym20),
Coiflets (coif1-coif17), Biorthogonal (bior/rbio),
Discrete Meyer (dmey), Haar (haar)
References
Stéphane Mallat. 2009. A Wavelet Tour of Signal Processing: The Sparse Way, 3rd edition. Academic Press, Boston.
Ingrid Daubechies. 1992. Ten Lectures on Wavelets. SIAM, Philadelphia.
Gilbert Strang and Truong Nguyen. 1996. Wavelets and Filter Banks. Wellesley-Cambridge Press, Wellesley.
PyWavelets Development Team. 2024. PyWavelets: Wavelet transforms in Python. https://pywavelets.readthedocs.io/
See Also
spectrans.transforms.base : Base transform interfaces
spectrans.layers.mixing.wavelet : Wavelet mixing layers
spectrans.transforms.fourier : Fourier transform implementations
Classes
DWT1D
Bases: MultiResolutionTransform
PyWavelets-compatible 1D Discrete Wavelet Transform.
This implementation matches PyWavelets behavior exactly; it was written against a detailed reading of the PyWavelets C code. It supports multi-level decomposition and achieves perfect reconstruction (maximum error below 1e-6) for all supported wavelets.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| wavelet | WaveletType | Wavelet type (e.g., 'db1', 'db2', 'db4', 'db8', 'sym2', 'coif1'). | 'db4' |
| levels | int | Number of decomposition levels. | 1 |
| mode | str | Boundary handling mode (currently only 'symmetric' supported). | 'symmetric' |
Attributes:
| Name | Type | Description |
|---|---|---|
| wavelet | str | The wavelet type being used. |
| levels | int | Number of decomposition levels. |
| mode | str | Boundary handling mode. |
| dec_lo | Tensor | Low-pass decomposition filter. |
| dec_hi | Tensor | High-pass decomposition filter. |
| rec_lo | Tensor | Low-pass reconstruction filter. |
| rec_hi | Tensor | High-pass reconstruction filter. |
| filter_length | int | Length of the wavelet filters. |
Examples:
>>> dwt = DWT1D(wavelet='db4', levels=3)
>>> x = torch.randn(16, 256) # batch_size=16, length=256
>>> cA, cD_list = dwt.decompose(x)
>>> print(f"Approximation shape: {cA.shape}")
>>> print(f"Number of detail levels: {len(cD_list)}")
>>> x_rec = dwt.reconstruct((cA, cD_list))
>>> error = torch.max(torch.abs(x - x_rec))
>>> print(f"Reconstruction error: {error:.2e}")
Methods:
| Name | Description |
|---|---|
| decompose | Multi-level DWT decomposition. |
| reconstruct | Multi-level DWT reconstruction. |
Source code in spectrans/transforms/wavelet.py
Functions
decompose
Multi-level DWT decomposition.
Recursively applies DWT to approximation coefficients.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Input signal. | required |
| levels | int \| None | Number of levels. If None, uses self.levels. | None |
| dim | int | Dimension to decompose along. | -1 |
Returns:
| Type | Description |
|---|---|
| tuple[Tensor, list[Tensor]] | Tuple of (approximation, [detail_1, ..., detail_N]) where details are ordered from finest to coarsest. |
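This ordering differs from `pywt.wavedec`, which returns `[cA_n, cD_n, ..., cD_1]` (coarsest detail first). When comparing results against PyWavelets, a conversion along these lines may help; `to_pywt_order` and `from_pywt_order` are hypothetical helpers, shown with string placeholders instead of tensors.

```python
def to_pywt_order(cA, details):
    """(cA, [cD_1 (finest), ..., cD_J (coarsest)]) -> [cA, cD_J, ..., cD_1]."""
    return [cA] + list(reversed(details))


def from_pywt_order(coeffs):
    """[cA, cD_J, ..., cD_1] -> (cA, [cD_1, ..., cD_J])."""
    return coeffs[0], list(reversed(coeffs[1:]))


coeffs = to_pywt_order("cA3", ["cD1", "cD2", "cD3"])
print(coeffs)  # ['cA3', 'cD3', 'cD2', 'cD1']
```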
reconstruct
reconstruct(coeffs: tuple[Tensor, list[Tensor]], dim: int = -1, output_len: int | None = None) -> Tensor
Multi-level DWT reconstruction.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| coeffs | tuple[Tensor, list[Tensor]] | Tuple of (approximation, [detail_1, ..., detail_N]). | required |
| dim | int | Dimension to reconstruct along. | -1 |
| output_len | int \| None | Desired output length. If provided, the reconstructed signal will be trimmed or padded to this length. | None |
Returns:
| Type | Description |
|---|---|
| Tensor | Reconstructed signal. |
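One reason `output_len` is useful: assuming the PyWavelets length formulas for non-periodic modes (\(\lfloor (n + L - 1)/2 \rfloor\) coefficients forward, \(2m - L + 2\) samples back), an odd-length input comes back one sample longer. The helpers below are hypothetical and only compute lengths.

```python
def dwt_len(n, filter_len):
    """Coefficient length after one symmetric-mode level: floor((n + F - 1) / 2)."""
    return (n + filter_len - 1) // 2


def idwt_len(n_coeffs, filter_len):
    """Signal length produced by one inverse step: 2 * n_coeffs - F + 2."""
    return 2 * n_coeffs - filter_len + 2


for n in (255, 256):
    m = dwt_len(n, 8)  # db4 has 8 taps
    print(n, m, idwt_len(m, 8))  # 255 -> 131 -> 256 (one extra sample), 256 -> 131 -> 256
```

Passing `output_len=255` in that case trims the extra sample.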
DWT2D
Bases: MultiResolutionTransform2D
PyWavelets-compatible 2D Discrete Wavelet Transform.
Implements the 2D DWT with separable 1D transforms, applying the 1D DWT along each of the two dimensions in turn. Returns coefficients in the PyWavelets-style format (LL, [(HL, LH, HH) per level]).
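A single-level separable step can be sketched with Haar filters in NumPy: filter along the last axis, then filter each result along the second-to-last axis, giving one approximation band and three detail bands. Band keys here are an illustrative convention (first letter for axis -2, second for axis -1, 'a' low-pass, 'd' high-pass), not the module's LL/HL/LH/HH names; even side lengths are assumed.

```python
import numpy as np

S = np.sqrt(2.0)


def haar_split(x, axis):
    """One Haar analysis step along `axis` (even length assumed): (low, high)."""
    a = np.take(x, np.arange(0, x.shape[axis], 2), axis=axis)
    b = np.take(x, np.arange(1, x.shape[axis], 2), axis=axis)
    return (a + b) / S, (a - b) / S


def haar_merge(lo, hi, axis):
    """Inverse of haar_split: interleave the two reconstructed phases along `axis`."""
    a, b = (lo + hi) / S, (lo - hi) / S
    shape = list(lo.shape)
    shape[axis] *= 2
    out = np.empty(shape)
    idx = [slice(None)] * lo.ndim
    idx[axis] = slice(0, None, 2)
    out[tuple(idx)] = a
    idx[axis] = slice(1, None, 2)
    out[tuple(idx)] = b
    return out


def dwt2_haar(img):
    """Separable single-level 2D Haar DWT over the last two axes."""
    lo, hi = haar_split(img, axis=-1)   # filter along the last axis first
    aa, da = haar_split(lo, axis=-2)    # then along the second-to-last axis
    ad, dd = haar_split(hi, axis=-2)
    return {"aa": aa, "ad": ad, "da": da, "dd": dd}


def idwt2_haar(c):
    """Undo dwt2_haar by reversing the order of the 1D steps."""
    lo = haar_merge(c["aa"], c["da"], axis=-2)
    hi = haar_merge(c["ad"], c["dd"], axis=-2)
    return haar_merge(lo, hi, axis=-1)


images = np.random.default_rng(0).standard_normal((4, 8, 8))  # batch of 4
bands = dwt2_haar(images)
print({k: v.shape for k, v in bands.items()})
print(np.allclose(idwt2_haar(bands), images))
```

Because each 1D step is invertible on its own, the 2D inverse only needs to undo the steps in reverse order; leading batch dimensions pass through untouched.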
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| wavelet | WaveletType | Wavelet type to use. | 'db4' |
| levels | int | Number of decomposition levels. | 1 |
| mode | str | Boundary handling mode. | 'symmetric' |
Attributes:
| Name | Type | Description |
|---|---|---|
| wavelet | str | The wavelet type. |
| levels | int | Number of decomposition levels. |
| mode | str | Boundary handling mode. |
| dwt1d | DWT1D | 1D DWT instance used for separable transforms. |
Examples:
>>> dwt2d = DWT2D(wavelet='db2', levels=2)
>>> image = torch.randn(4, 64, 64) # batch of 4 images
>>> ll, detail_bands = dwt2d.decompose(image)
>>> print(f"LL shape: {ll.shape}")
>>> for i, (hl, lh, hh) in enumerate(detail_bands):
...     print(f"Level {i+1} - HL: {hl.shape}, LH: {lh.shape}, HH: {hh.shape}")
>>> reconstructed = dwt2d.reconstruct((ll, detail_bands))
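To anticipate the shapes printed by the example above, the per-axis size at each level can be computed from the symmetric-mode length formula \(\lfloor (n + L - 1)/2 \rfloor\), assuming `DWT2D` follows it along both axes as the module notes state (`db2` has 4 taps). `level_sizes` is a hypothetical helper.

```python
def level_sizes(n, filter_len, levels):
    """Per-level coefficient size along one axis in symmetric mode."""
    sizes = []
    for _ in range(levels):
        n = (n + filter_len - 1) // 2   # each level works on the previous approximation
        sizes.append(n)
    return sizes


print(level_sizes(64, filter_len=4, levels=2))  # [33, 18]
```

Under that assumption, each 64×64 image yields 33×33 level-1 detail bands and 18×18 level-2 detail bands and LL, with the batch dimension of 4 unchanged.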
Methods:
| Name | Description |
|---|---|
| decompose | Multi-level 2D DWT decomposition. |
| reconstruct | Multi-level 2D DWT reconstruction. |
Functions
decompose
decompose(x: Tensor, levels: int | None = None, dim: tuple[int, int] = (-2, -1)) -> tuple[Tensor, list[tuple[Tensor, Tensor, Tensor]]]
Multi-level 2D DWT decomposition.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Input 2D tensor. | required |
| levels | int \| None | Number of levels. If None, uses self.levels. | None |
| dim | tuple[int, int] | Dimensions to decompose along. | (-2, -1) |
Returns:
| Type | Description |
|---|---|
| tuple[Tensor, list[tuple[Tensor, Tensor, Tensor]]] | Tuple of (LL, [(HL, LH, HH) per level]) following PyWavelets convention, where HL is the horizontal detail, LH the vertical detail, and HH the diagonal detail. |
reconstruct
reconstruct(coeffs: tuple[Tensor, list[tuple[Tensor, Tensor, Tensor]]], dim: tuple[int, int] = (-2, -1)) -> Tensor
Multi-level 2D DWT reconstruction.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| coeffs | tuple[Tensor, list[tuple[Tensor, Tensor, Tensor]]] | Tuple of (LL, [(HL, LH, HH) per level]) following PyWavelets convention. | required |
| dim | tuple[int, int] | Dimensions to reconstruct along. | (-2, -1) |
Returns:
| Type | Description |
|---|---|
| Tensor | Reconstructed 2D tensor. |
Functions
get_wavelet_filters
Get filter coefficients from PyWavelets.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| wavelet_name | str | Name of the wavelet (e.g., 'db1', 'db2', 'db4', 'sym2'). | required |
Returns:
| Type | Description |
|---|---|
| tuple[Tensor, Tensor, Tensor, Tensor] | Tuple of (dec_lo, dec_hi, rec_lo, rec_hi) filter tensors. |
Raises:
| Type | Description |
|---|---|
| ValueError | If wavelet is not supported by PyWavelets. |
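PyWavelets exposes the same four filters as `pywt.Wavelet(name).filter_bank`. For orthogonal wavelets the four are tied together, and the sketch below derives them from `dec_lo` alone under the assumed relations `rec_lo = dec_lo[::-1]`, `rec_hi[n] = (-1)^n dec_lo[n]`, `dec_hi = rec_hi[::-1]`. These relations do not apply to biorthogonal families (`bior`/`rbio`), so compare against `filter_bank` before relying on them; `orthogonal_filter_bank` is a hypothetical helper.

```python
import math


def orthogonal_filter_bank(dec_lo):
    """Derive (dec_lo, dec_hi, rec_lo, rec_hi) from dec_lo (orthogonal wavelets only)."""
    dec_lo = list(dec_lo)
    rec_lo = dec_lo[::-1]                                   # time reversal
    rec_hi = [(-1) ** n * c for n, c in enumerate(dec_lo)]  # alternate signs
    dec_hi = rec_hi[::-1]                                   # time reversal
    return dec_lo, dec_hi, rec_lo, rec_hi


s = 1 / math.sqrt(2)
print(orthogonal_filter_bank([s, s]))  # Haar
```

For Haar this gives `dec_hi = [-1/√2, 1/√2]` and `rec_hi = [1/√2, -1/√2]`.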