
Wavelet Transforms

spectrans.transforms.wavelet

PyWavelets-compatible Discrete Wavelet Transform implementations.

This module provides DWT implementations designed to match PyWavelets' behavior exactly while remaining fully differentiable through PyTorch autograd.

Classes:

  • DWT1D: 1D Discrete Wavelet Transform with multi-level support.
  • DWT2D: 2D Discrete Wavelet Transform using separable 1D transforms.

Functions:

  • get_wavelet_filters: Extract filter coefficients from PyWavelets.

Examples:

Basic 1D wavelet transform:

>>> import torch
>>> from spectrans.transforms.wavelet import DWT1D
>>> dwt = DWT1D(wavelet='db4', levels=2)
>>> x = torch.randn(32, 256)
>>> cA, cD_list = dwt.decompose(x)
>>> x_rec = dwt.reconstruct((cA, cD_list))
>>> error = torch.max(torch.abs(x - x_rec))
>>> print(f"Reconstruction error: {error:.2e}")  # Should be < 1e-6

2D wavelet transform for images:

>>> from spectrans.transforms.wavelet import DWT2D
>>> dwt2d = DWT2D(wavelet='db2', levels=2)
>>> image = torch.randn(1, 64, 64)
>>> ll, detail_bands = dwt2d.decompose(image)
>>> reconstructed = dwt2d.reconstruct((ll, detail_bands))

Multi-level decomposition with energy analysis:

>>> dwt = DWT1D(wavelet='db4', levels=3)
>>> x = torch.randn(1, 512)
>>> cA, cD_list = dwt.decompose(x)
>>> # Orthogonal wavelets approximately preserve energy (Parseval); the
>>> # extra boundary coefficients of 'symmetric' mode make it inexact
>>> energy_input = torch.sum(x ** 2)
>>> energy_coeffs = torch.sum(cA ** 2) + sum(torch.sum(cD ** 2) for cD in cD_list)
>>> print(f"Energy ratio: {energy_coeffs / energy_input:.6f}")  # Close to, but not exactly, 1.0

Notes

Mathematical Foundations: The Discrete Wavelet Transform (DWT) decomposes a signal into approximation and detail coefficients through iterative filtering and downsampling.

For a signal \(\mathbf{x}[n]\) of length \(N\), the single-level DWT produces:

\[ c_A[k] = \sum_{n} h[n-2k] \cdot \mathbf{x}[n] \]
\[ c_D[k] = \sum_{n} g[n-2k] \cdot \mathbf{x}[n] \]

where \(h[n]\) and \(g[n]\) are the low-pass and high-pass analysis filters. Given the synthesis filters \(h'[n]\) and \(g'[n]\), reconstruction is:

\[ \mathbf{x}[n] = \sum_{k} h'[n-2k] \cdot c_A[k] + \sum_{k} g'[n-2k] \cdot c_D[k] \]
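The two analysis sums and the synthesis sum can be checked directly in a few lines of pure Python. This sketch assumes the Haar filters \(h = [1/\sqrt{2}, 1/\sqrt{2}]\) and \(g = [1/\sqrt{2}, -1/\sqrt{2}]\), supported on \(n - 2k \in \{0, 1\}\), and an even-length signal so that no boundary extension is needed. It reuses the analysis filters for synthesis, which is what orthogonality gives in this correlation form. The helpers (tap, analyze, synthesize) are illustrative and not part of spectrans.

```python
import math

# Haar analysis filters; h[n - 2k] is nonzero only for n - 2k in {0, 1}.
H = [1 / math.sqrt(2), 1 / math.sqrt(2)]   # low-pass h
G = [1 / math.sqrt(2), -1 / math.sqrt(2)]  # high-pass g


def tap(f, m):
    """Return f[m], treating indices outside the filter support as zero."""
    return f[m] if 0 <= m < len(f) else 0.0


def analyze(x):
    """c_A[k] = sum_n h[n-2k] x[n], c_D[k] = sum_n g[n-2k] x[n] (even len(x))."""
    half = len(x) // 2
    cA = [sum(tap(H, n - 2 * k) * x[n] for n in range(len(x))) for k in range(half)]
    cD = [sum(tap(G, n - 2 * k) * x[n] for n in range(len(x))) for k in range(half)]
    return cA, cD


def synthesize(cA, cD):
    """x[n] = sum_k h[n-2k] c_A[k] + sum_k g[n-2k] c_D[k], reusing h and g."""
    return [
        sum(tap(H, n - 2 * k) * cA[k] + tap(G, n - 2 * k) * cD[k] for k in range(len(cA)))
        for n in range(2 * len(cA))
    ]


x = [4.0, 2.0, 5.0, 7.0, 1.0, 3.0]
cA, cD = analyze(x)
x_rec = synthesize(cA, cD)
print([round(v, 9) for v in x_rec])  # [4.0, 2.0, 5.0, 7.0, 1.0, 3.0]
```

Because the Haar basis is orthonormal and no boundary samples are involved here, the coefficient energy equals the signal energy up to rounding.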

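Multi-Resolution Analysis: The \(J\)-level DWT recursively applies the transform to the approximation coefficients, creating a dyadic decomposition in which level \(j\) has a length of roughly \(N/2^j\) and nominally covers the frequency band \([0, \pi/2^j]\) for approximations and \([\pi/2^j, \pi/2^{j-1}]\) for details. With the 'symmetric' boundary mode the exact length is \(N_j = \lfloor (N_{j-1} + L - 1)/2 \rfloor\), where \(N_0 = N\) and \(L\) is the filter length.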
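The per-level lengths can be computed without running the transform. A minimal sketch, assuming the non-periodization rule \(N_j = \lfloor (N_{j-1} + L - 1)/2 \rfloor\) (to our knowledge the same rule PyWavelets exposes as pywt.dwt_coeff_len, which additionally takes a mode argument); the function names here are hypothetical.

```python
def coeff_len(data_len: int, filter_len: int) -> int:
    """Coefficients produced by one level in a non-periodization mode such as 'symmetric'."""
    return (data_len + filter_len - 1) // 2


def level_lengths(n: int, filter_len: int, levels: int) -> list:
    """Coefficient length N_j at each level j = 1..levels, starting from N_0 = n."""
    lengths = []
    for _ in range(levels):
        n = coeff_len(n, filter_len)
        lengths.append(n)
    return lengths


# db4 has 8 taps; compare with the idealized N / 2**j.
print(level_lengths(512, 8, 3))            # [259, 133, 70]
print([512 // 2**j for j in (1, 2, 3)])    # [256, 128, 64]
```

The gap between the two lists is the boundary overhang of roughly \((L-1)/2\) coefficients added at every level.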

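Perfect Reconstruction: For orthogonal wavelets, the synthesis filters in the correlation form above equal the analysis filters, \(h'[n] = h[n]\) and \(g'[n] = g[n]\); in PyWavelets' convolution convention this shows up as rec_lo and rec_hi being the time-reversed dec_lo and dec_hi. For an infinite or periodized signal the transform also preserves energy: \(\|\mathbf{x}\|^2 = \|\mathbf{c}_{A_J}\|^2 + \sum_{j=1}^{J} \|\mathbf{c}_{D_j}\|^2\). With the 'symmetric' boundary mode used here, the extra boundary coefficients make this equality only approximate.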
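The orthogonality behind these statements can be checked on concrete filters. This sketch builds db2 in closed form, \(h = [1+\sqrt{3},\, 3+\sqrt{3},\, 3-\sqrt{3},\, 1-\sqrt{3}]/(4\sqrt{2})\), and derives the other three filters from relations that PyWavelets' db2 coefficients satisfy (dec_lo is rec_lo reversed, rec_hi[n] = (-1)^n rec_lo[L-1-n], dec_hi is rec_hi reversed). These relations are assumed here rather than read from pywt, and shifted_dot is a hypothetical helper.

```python
import math

# db2 scaling filter in closed form (matching PyWavelets' db2 rec_lo values).
s3 = math.sqrt(3.0)
rec_lo = [c / (4 * math.sqrt(2.0)) for c in (1 + s3, 3 + s3, 3 - s3, 1 - s3)]
L = len(rec_lo)

# Relations assumed for an orthogonal filter bank in PyWavelets' convention.
dec_lo = rec_lo[::-1]                                        # analysis = reversed synthesis
rec_hi = [(-1) ** n * rec_lo[L - 1 - n] for n in range(L)]   # quadrature-mirror high-pass
dec_hi = rec_hi[::-1]


def shifted_dot(a, b, shift):
    """sum_n a[n] * b[n + shift], treating out-of-range taps as zero."""
    return sum(a[n] * b[n + shift] for n in range(len(a)) if 0 <= n + shift < len(b))


tol = 1e-12
print(abs(shifted_dot(rec_lo, rec_lo, 0) - 1) < tol)   # True: unit norm
print(abs(shifted_dot(rec_lo, rec_lo, 2)) < tol)       # True: orthogonal to its 2-shift
print(all(abs(shifted_dot(rec_lo, rec_hi, s)) < tol for s in (-2, 0, 2)))  # True: low/high orthogonal
print([round(c, 4) for c in dec_lo])  # [-0.1294, 0.2241, 0.8365, 0.483]
```

Unit norm plus orthogonality to even shifts is exactly what lets synthesis reuse the analysis basis and keeps the coefficient energy equal to the signal energy away from boundaries.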

Implementation Details:

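  • Downsampled convolution keeps outputs starting at index \((\text{step} - 1) = 1\) for stride 2, as PyWavelets does
  • Symmetric mode is half-sample symmetric, so the edge sample is repeated: [d,c,b,a | a,b,c,d | d,c,b,a] (PyWavelets' 'reflect' mode is the variant without the repeat)
  • conv1d computes cross-correlation, so the filters are flipped to obtain the true convolution that PyWavelets performs
  • IDWT uses conv_transpose1d with stride 2 for implicit upsampling
  • Output lengths follow PyWavelets: \(\lfloor (N + L - 1)/2 \rfloor\) coefficients for a length-\(N\) input and filter length \(L\), and \(2N_c - L + 2\) samples from an IDWT of \(N_c\) coefficients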
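These rules can be sketched for a single level in plain loops instead of conv1d / conv_transpose1d. The sketch assumes a signal at least as long as the filter (so one reflection per side suffices), builds db2 in closed form rather than fetching it from PyWavelets, and assumes PyWavelets' non-periodization reconstruction length \(2N_c - L + 2\). The helper names (symmetric_index, dwt_symmetric, idwt) are hypothetical, not spectrans functions.

```python
import math


def symmetric_index(m: int, n: int) -> int:
    """Map index m of the half-sample symmetric extension back into [0, n).

    Assumes the overhang is at most n samples (one reflection per side).
    """
    if m < 0:
        return -m - 1          # ... x[1] x[0] | x[0] x[1] ...
    if m >= n:
        return 2 * n - 1 - m   # ... x[n-2] x[n-1] | x[n-1] x[n-2] ...
    return m


def dwt_symmetric(x, dec_lo, dec_hi):
    """One level: full convolution of the extended signal, keeping i = 1, 3, 5, ..."""
    n, flen = len(x), len(dec_lo)
    cA, cD = [], []
    for k in range((n + flen - 1) // 2):
        i = 2 * k + 1  # first kept index is step - 1 = 1
        window = [x[symmetric_index(i - j, n)] for j in range(flen)]
        cA.append(sum(f * v for f, v in zip(dec_lo, window)))
        cD.append(sum(f * v for f, v in zip(dec_hi, window)))
    return cA, cD


def idwt(cA, cD, rec_lo, rec_hi):
    """Zero-upsample, fully convolve, keep the 2*len(cA) - L + 2 valid samples."""
    flen = len(rec_lo)
    out = []
    for n in range(2 * len(cA) - flen + 2):
        total = 0.0
        for k in range(len(cA)):
            j = n + flen - 2 - 2 * k  # filter tap hitting coefficient k (placed at 2k)
            if 0 <= j < flen:
                total += rec_lo[j] * cA[k] + rec_hi[j] * cD[k]
        out.append(total)
    return out


# db2 filters built in closed form (not fetched from pywt).
s3 = math.sqrt(3.0)
rec_lo = [c / (4 * math.sqrt(2.0)) for c in (1 + s3, 3 + s3, 3 - s3, 1 - s3)]
rec_hi = [(-1) ** n * rec_lo[3 - n] for n in range(4)]
dec_lo, dec_hi = rec_lo[::-1], rec_hi[::-1]

x = [3.0, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0, 6.0]
cA, cD = dwt_symmetric(x, dec_lo, dec_hi)
x_rec = idwt(cA, cD, rec_lo, rec_hi)
print(len(cA), len(x_rec))                                  # 5 8
print(max(abs(a - b) for a, b in zip(x, x_rec)) < 1e-12)     # True
```

The round trip is exact (to rounding) even though the five coefficients per band exceed \(N/2 = 4\): the extra boundary coefficients carry what the symmetric extension added.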

Algorithm Complexity:

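  • Forward/Inverse DWT: \(O(NL)\) for an \(N\)-length signal and filter length \(L\), i.e. \(O(N)\) for a fixed wavelet; since each level roughly halves the length, all levels together cost about twice the first level
  • Memory: \(O(N)\) for coefficients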
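A quick operation count makes the linear cost concrete. This sketch assumes a direct implementation that spends filter_len multiply-accumulates per output coefficient for each of the two filters, with 'symmetric'-mode lengths; it ignores how conv1d actually schedules the work, so treat it as an idealized count. dwt_macs is a hypothetical helper.

```python
def dwt_macs(n: int, filter_len: int, levels: int) -> list:
    """Idealized multiply-accumulates per level for a direct 'symmetric'-mode DWT."""
    macs = []
    for _ in range(levels):
        n = (n + filter_len - 1) // 2       # coefficients produced at this level
        macs.append(2 * filter_len * n)     # two filters, filter_len taps per output
    return macs


per_level = dwt_macs(1024, 8, 5)             # db4: 8 taps
print(per_level)                             # [8240, 4176, 2144, 1120, 608]
print(round(sum(per_level) / per_level[0], 3))  # 1.977
```

For a 1024-sample signal with db4 and five levels, the total is just under twice the first level, consistent with the geometric shrinkage of the approximation.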

Gradient Support: All operations use native PyTorch operations ensuring full autograd support.

Numerical Precision:

  • Filters use float64 for extraction, float32 for computation
  • Perfect reconstruction to \(\sim 10^{-7}\) for float32

Supported Wavelets: Daubechies (db1-db38), Symlets (sym2-sym20), Coiflets (coif1-coif17), Biorthogonal (bior/rbio), Discrete Meyer (dmey), Haar (haar)

References

Stéphane Mallat. 2009. A Wavelet Tour of Signal Processing: The Sparse Way, 3rd edition. Academic Press, Boston.

Ingrid Daubechies. 1992. Ten Lectures on Wavelets. SIAM, Philadelphia.

Gilbert Strang and Truong Nguyen. 1996. Wavelets and Filter Banks. Wellesley-Cambridge Press, Wellesley.

PyWavelets Development Team. 2024. PyWavelets: Wavelet transforms in Python. https://pywavelets.readthedocs.io/

See Also

spectrans.transforms.base : Base transform interfaces
spectrans.layers.mixing.wavelet : Wavelet mixing layers
spectrans.transforms.fourier : Fourier transform implementations

Classes

DWT1D

DWT1D(wavelet: WaveletType = 'db4', levels: int = 1, mode: str = 'symmetric')

Bases: MultiResolutionTransform

PyWavelets-compatible 1D Discrete Wavelet Transform.

This implementation mirrors PyWavelets' C implementation (boundary extension, convolution offsets, and output lengths), supports multi-level decomposition, and achieves perfect reconstruction (error < 1e-6 in float32) for all supported wavelets.

Parameters:

  • wavelet (WaveletType, default 'db4'): Wavelet type (e.g., 'db1', 'db2', 'db4', 'db8', 'sym2', 'coif1').
  • levels (int, default 1): Number of decomposition levels.
  • mode (str, default 'symmetric'): Boundary handling mode. Only 'symmetric' is currently supported; any other value raises NotImplementedError.

Attributes:

  • wavelet (str): The wavelet type being used.
  • levels (int): Number of decomposition levels.
  • mode (str): Boundary handling mode.
  • dec_lo (Tensor): Low-pass decomposition filter.
  • dec_hi (Tensor): High-pass decomposition filter.
  • rec_lo (Tensor): Low-pass reconstruction filter.
  • rec_hi (Tensor): High-pass reconstruction filter.
  • filter_length (int): Length of the wavelet filters.

Examples:

>>> dwt = DWT1D(wavelet='db4', levels=3)
>>> x = torch.randn(16, 256)  # batch_size=16, length=256
>>> cA, cD_list = dwt.decompose(x)
>>> print(f"Approximation shape: {cA.shape}")
>>> print(f"Number of detail levels: {len(cD_list)}")
>>> x_rec = dwt.reconstruct((cA, cD_list))
>>> error = torch.max(torch.abs(x - x_rec))
>>> print(f"Reconstruction error: {error:.2e}")

Methods:

  • decompose: Multi-level DWT decomposition.
  • reconstruct: Multi-level DWT reconstruction.

Source code in spectrans/transforms/wavelet.py
def __init__(self, wavelet: WaveletType = "db4", levels: int = 1, mode: str = "symmetric"):
    super().__init__(levels=levels)
    self.wavelet = wavelet
    self.mode = mode

    if mode != "symmetric":
        msg = f"Mode '{mode}' not yet supported. Only 'symmetric' is implemented."
        raise NotImplementedError(msg)

    # Get filter coefficients from PyWavelets
    dec_lo, dec_hi, rec_lo, rec_hi = get_wavelet_filters(wavelet)

    # Register as buffers (not parameters) since they're fixed
    # Convert to float32 for efficiency in neural networks
    self.register_buffer("dec_lo", dec_lo.float())
    self.register_buffer("dec_hi", dec_hi.float())
    self.register_buffer("rec_lo", rec_lo.float())
    self.register_buffer("rec_hi", rec_hi.float())

    self.filter_length = len(dec_lo)
Functions
decompose
decompose(x: Tensor, levels: int | None = None, dim: int = -1) -> tuple[Tensor, list[Tensor]]

Multi-level DWT decomposition.

Recursively applies DWT to approximation coefficients.

Parameters:

  • x (Tensor, required): Input signal.
  • levels (int | None, default None): Number of levels. If None, uses self.levels.
  • dim (int, default -1): Dimension to decompose along.

Returns:

  • tuple[Tensor, list[Tensor]]: Tuple of (approximation, [detail_1, ..., detail_N]) where details are ordered from finest to coarsest. This is the reverse of pywt.wavedec, which returns [cA_N, cD_N, ..., cD_1].

Source code in spectrans/transforms/wavelet.py
def decompose(
    self, x: Tensor, levels: int | None = None, dim: int = -1
) -> tuple[Tensor, list[Tensor]]:
    """Multi-level DWT decomposition.

    Recursively applies DWT to approximation coefficients.

    Parameters
    ----------
    x : Tensor
        Input signal.
    levels : int | None
        Number of levels. If None, uses self.levels.
    dim : int
        Dimension to decompose along.

    Returns
    -------
    tuple[Tensor, list[Tensor]]
        Tuple of (approximation, [detail_1, ..., detail_N])
        where details are ordered from finest to coarsest.
    """
    if levels is None:
        levels = self.levels

    current = x
    details = []

    # Apply DWT recursively
    for _ in range(levels):
        cA, cD = self._single_dwt(current, dim=dim)
        details.append(cD)
        current = cA

    return current, details
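To compare results with pywt.wavedec, which returns [cA_n, cD_n, ..., cD_1] (coarsest detail first), the detail list has to be reversed. A minimal sketch with hypothetical helper names; it only reorders Python objects, so it behaves the same for tensors or NumPy arrays.

```python
from typing import Any, List, Tuple


def to_wavedec_order(cA: Any, details: List[Any]) -> List[Any]:
    """(cA, [cD_1, ..., cD_n]) from decompose -> [cA_n, cD_n, ..., cD_1] like pywt.wavedec."""
    return [cA, *reversed(details)]


def from_wavedec_order(coeffs: List[Any]) -> Tuple[Any, List[Any]]:
    """[cA_n, cD_n, ..., cD_1] -> (cA, [cD_1, ..., cD_n]) as reconstruct expects."""
    cA, *coarse_to_fine = coeffs
    return cA, coarse_to_fine[::-1]


# String labels stand in for tensors so the reordering is visible.
decomposed = ("cA3", ["cD1", "cD2", "cD3"])
print(to_wavedec_order(*decomposed))                         # ['cA3', 'cD3', 'cD2', 'cD1']
print(from_wavedec_order(to_wavedec_order(*decomposed)) == decomposed)  # True
```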
reconstruct
reconstruct(coeffs: tuple[Tensor, list[Tensor]], dim: int = -1, output_len: int | None = None) -> Tensor

Multi-level DWT reconstruction.

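Parameters:

  • coeffs (tuple[Tensor, list[Tensor]], required): Tuple of (approximation, [detail_1, ..., detail_N]), with details ordered from finest to coarsest as returned by decompose.
  • dim (int, default -1): Dimension to reconstruct along.
  • output_len (int | None, default None): Desired output length. If provided, the reconstructed signal is trimmed to this length along dim. No padding is applied, so a value larger than the reconstructed length leaves the output unchanged.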

Returns:

  • Tensor: Reconstructed signal.

Source code in spectrans/transforms/wavelet.py
def reconstruct(
    self, coeffs: tuple[Tensor, list[Tensor]], dim: int = -1, output_len: int | None = None
) -> Tensor:
    """Multi-level DWT reconstruction.

    Parameters
    ----------
    coeffs : tuple[Tensor, list[Tensor]]
        Tuple of (approximation, [detail_1, ..., detail_N]).
    dim : int
        Dimension to reconstruct along.
    output_len : int | None
        Desired output length. If provided, the reconstructed signal
        is trimmed to this length along dim; no padding is applied.

    Returns
    -------
    Tensor
        Reconstructed signal.
    """
    cA, details = coeffs
    current = cA

    # Reconstruct from coarsest to finest (reverse order)
    for _i, cD in enumerate(reversed(details)):
        # For multi-level, we need to handle size mismatches
        # The reconstructed signal from a coarser level may be slightly
        # longer than the detail coefficients from the finer level

        if current.shape[dim] > cD.shape[dim]:
            # Trim current to match cD size
            # This happens because IDWT can produce slightly longer output
            target_len = cD.shape[dim]
            if dim == -1 or dim == current.ndim - 1:
                current = current[..., :target_len]
            elif dim == 0:
                current = current[:target_len]
            else:
                # General case
                slices = [slice(None)] * current.ndim
                slices[dim] = slice(0, target_len)
                current = current[tuple(slices)]
        elif current.shape[dim] < cD.shape[dim]:
            # This shouldn't happen with correct decomposition
            raise ValueError(
                f"Reconstructed signal smaller than detail coefficients. "
                f"current shape: {current.shape}, cD shape: {cD.shape}"
            )

        # Now they should have matching sizes
        current = self._single_idwt(current, cD, dim=dim)

    # Trim to desired output length if specified
    if output_len is not None and current.shape[dim] != output_len:
        if dim == -1 or dim == current.ndim - 1:
            current = current[..., :output_len]
        elif dim == 0:
            current = current[:output_len]
        else:
            slices = [slice(None)] * current.ndim
            slices[dim] = slice(0, output_len)
            current = current[tuple(slices)]

    return current
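The size checks in reconstruct can be traced using lengths alone. This sketch assumes _single_idwt returns \(2N_c - L + 2\) samples for \(N_c\) input coefficients (PyWavelets' length for non-periodization modes; the helper's source is not shown here) and replays the loop for db4 (\(L = 8\)). The function names are hypothetical.

```python
def dwt_len(n: int, filter_len: int) -> int:
    return (n + filter_len - 1) // 2       # analysis length in 'symmetric' mode


def idwt_len(n_coeffs: int, filter_len: int) -> int:
    return 2 * n_coeffs - filter_len + 2   # synthesis length before any trimming


def trace_reconstruction(n: int, filter_len: int, levels: int):
    """Replay reconstruct()'s size checks using lengths only.

    Returns ([(running_len, detail_len) per step, coarsest first], final_len).
    """
    lengths = [n]
    for _ in range(levels):
        lengths.append(dwt_len(lengths[-1], filter_len))   # lengths[j] = N_j
    current = lengths[levels]                 # coarsest approximation
    checks = []
    for j in range(levels, 0, -1):            # same order as reversed(details)
        detail_len = lengths[j]               # detail at level j has N_j samples
        checks.append((current, detail_len))
        current = min(current, detail_len)    # the trim branch in reconstruct()
        current = idwt_len(current, filter_len)
    return checks, current


print(trace_reconstruction(100, 8, 3))  # ([(18, 18), (30, 30), (54, 53)], 100)
print(trace_reconstruction(101, 8, 3))  # ([(18, 18), (30, 30), (54, 54)], 102)
```

For N = 100 the last step meets a 54-sample signal against 53 detail coefficients, the case the trimming branch handles because the intermediate length 53 was odd. For N = 101 no trim is needed along the way, but the final output has 102 samples, which is what output_len is for.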

DWT2D

DWT2D(wavelet: WaveletType = 'db4', levels: int = 1, mode: str = 'symmetric')

Bases: MultiResolutionTransform2D

PyWavelets-compatible 2D Discrete Wavelet Transform.

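Implements 2D DWT using separable 1D transforms, applying DWT along each dimension sequentially. Returns coefficients in the format (LL, [(HL, LH, HH) per level]), which matches PyWavelets' (cH, cV, cD) ordering within each level; levels are listed from finest to coarsest.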
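A pure-Python sketch of the separable scheme for one Haar level, assuming even height and width so no boundary extension is needed. It filters along the last axis and then along the first (the order does not matter, since both passes are linear and act on different axes) and names bands by which filter ran along which axis, in the spirit of pywt.dwtn's 'aa'/'ad'/'da'/'dd' keys, instead of committing to how this module maps them onto HL and LH. The helpers are hypothetical and not spectrans's _single_level_2d.

```python
import math

R = 1 / math.sqrt(2)


def haar_1d(v):
    """Single-level Haar DWT of an even-length sequence: (approximation, detail)."""
    a = [(v[2 * k] + v[2 * k + 1]) * R for k in range(len(v) // 2)]
    d = [(v[2 * k] - v[2 * k + 1]) * R for k in range(len(v) // 2)]
    return a, d


def transpose(m):
    return [list(col) for col in zip(*m)]


def haar_2d(img):
    """Separable single-level 2D Haar on a list of rows (even height and width).

    Band keys: first letter = filter along the first axis, second = along the last
    axis ('a' = low-pass approximation, 'd' = high-pass detail).
    """
    rows_a, rows_d = zip(*(haar_1d(row) for row in img))  # 1D DWT of every row
    bands = {}
    for last, half in (("a", rows_a), ("d", rows_d)):
        cols_a, cols_d = zip(*(haar_1d(col) for col in transpose(half)))  # every column
        bands["a" + last] = transpose(cols_a)
        bands["d" + last] = transpose(cols_d)
    return bands


img = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]
bands = haar_2d(img)
print({k: [[round(v, 6) for v in row] for row in m] for k, m in sorted(bands.items())})
```

For this 4×4 ramp the 'aa' band is [[7, 11], [23, 27]] up to rounding (each 2×2 block sum divided by 2), the 'ad' and 'da' bands are constant at -1 and -4, and 'dd' is zero; the squared values sum to 1496, the energy of the input.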

Parameters:

  • wavelet (WaveletType, default 'db4'): Wavelet type to use.
  • levels (int, default 1): Number of decomposition levels.
  • mode (str, default 'symmetric'): Boundary handling mode. It is passed to the internal DWT1D, so only 'symmetric' is currently supported.

Attributes:

  • wavelet (str): The wavelet type.
  • levels (int): Number of decomposition levels.
  • mode (str): Boundary handling mode.
  • dwt1d (DWT1D): Single-level 1D DWT instance used for the separable transforms.

Examples:

>>> dwt2d = DWT2D(wavelet='db2', levels=2)
>>> image = torch.randn(4, 64, 64)  # batch of 4 images
>>> ll, detail_bands = dwt2d.decompose(image)
>>> print(f"LL shape: {ll.shape}")
>>> for i, (hl, lh, hh) in enumerate(detail_bands):
...     print(f"Level {i+1} - HL: {hl.shape}, LH: {lh.shape}, HH: {hh.shape}")
>>> reconstructed = dwt2d.reconstruct((ll, detail_bands))

Methods:

  • decompose: Multi-level 2D DWT decomposition.
  • reconstruct: Multi-level 2D DWT reconstruction.

Source code in spectrans/transforms/wavelet.py
def __init__(self, wavelet: WaveletType = "db4", levels: int = 1, mode: str = "symmetric"):
    super().__init__(levels=levels)
    self.wavelet = wavelet
    self.mode = mode

    # Use 1D DWT for separable 2D transform
    self.dwt1d = DWT1D(wavelet=wavelet, levels=1, mode=mode)
Functions
decompose
decompose(x: Tensor, levels: int | None = None, dim: tuple[int, int] = (-2, -1)) -> tuple[Tensor, list[tuple[Tensor, Tensor, Tensor]]]

Multi-level 2D DWT decomposition.

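Parameters:

  • x (Tensor, required): Input tensor; the transform is applied over the two dimensions given by dim (e.g. a batch of images of shape (B, H, W)).
  • levels (int | None, default None): Number of levels. If None, uses self.levels.
  • dim (tuple[int, int], default (-2, -1)): Dimensions to decompose along.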

Returns:

  • tuple[Tensor, list[tuple[Tensor, Tensor, Tensor]]]: Tuple of (LL, [(HL, LH, HH) per level]) following the PyWavelets convention, where HL is the horizontal detail, LH the vertical detail, and HH the diagonal detail. Levels are ordered from finest to coarsest, the reverse of pywt.wavedec2.

Source code in spectrans/transforms/wavelet.py
def decompose(
    self, x: Tensor, levels: int | None = None, dim: tuple[int, int] = (-2, -1)
) -> tuple[Tensor, list[tuple[Tensor, Tensor, Tensor]]]:
    """Multi-level 2D DWT decomposition.

    Parameters
    ----------
    x : Tensor
        Input tensor; the transform is applied over the two dimensions in dim.
    levels : int | None
        Number of levels. If None, uses self.levels.
    dim : tuple[int, int]
        Dimensions to decompose along.

    Returns
    -------
    tuple[Tensor, list[tuple[Tensor, Tensor, Tensor]]]
        Tuple of (LL, [(HL, LH, HH) per level]) following PyWavelets convention
        where HL is horizontal detail, LH is vertical detail, HH is diagonal detail.
    """
    if levels is None:
        levels = self.levels

    current = x
    detail_bands = []

    for _ in range(levels):
        ll, lh, hl, hh = self._single_level_2d(current, dim=dim)
        # PyWavelets returns (cH, cV, cD) = (HL, LH, HH)
        # So we append (HL, LH, HH) to match
        detail_bands.append((hl, lh, hh))
        current = ll

    return current, detail_bands
reconstruct
reconstruct(coeffs: tuple[Tensor, list[tuple[Tensor, Tensor, Tensor]]], dim: tuple[int, int] = (-2, -1)) -> Tensor

Multi-level 2D DWT reconstruction.

Parameters:

  • coeffs (tuple[Tensor, list[tuple[Tensor, Tensor, Tensor]]], required): Tuple of (LL, [(HL, LH, HH) per level]) following the PyWavelets convention, as returned by decompose.
  • dim (tuple[int, int], default (-2, -1)): Dimensions to reconstruct along.

Returns:

  • Tensor: Reconstructed 2D tensor.

Source code in spectrans/transforms/wavelet.py
def reconstruct(
    self,
    coeffs: tuple[Tensor, list[tuple[Tensor, Tensor, Tensor]]],
    dim: tuple[int, int] = (-2, -1),
) -> Tensor:
    """Multi-level 2D DWT reconstruction.

    Parameters
    ----------
    coeffs : tuple[Tensor, list[tuple[Tensor, Tensor, Tensor]]]
        Tuple of (LL, [(HL, LH, HH) per level]) following PyWavelets convention.
    dim : tuple[int, int]
        Dimensions to reconstruct along.

    Returns
    -------
    Tensor
        Reconstructed 2D tensor.
    """
    ll, detail_bands = coeffs
    current = ll

    # Reconstruct from coarsest to finest
    # detail_bands contains (HL, LH, HH) tuples following PyWavelets convention
    for hl, lh, hh in reversed(detail_bands):
        current = self._single_level_2d_reconstruct(current, lh, hl, hh, dim=dim)

    return current

Functions

get_wavelet_filters

get_wavelet_filters(wavelet_name: str) -> tuple[Tensor, Tensor, Tensor, Tensor]

Get filter coefficients from PyWavelets.

Parameters:

  • wavelet_name (str, required): Name of the wavelet (e.g., 'db1', 'db2', 'db4', 'sym2').

Returns:

  • tuple[Tensor, Tensor, Tensor, Tensor]: Tuple of (dec_lo, dec_hi, rec_lo, rec_hi) filter tensors in float64.

Raises:

  • ValueError: If the wavelet is not supported by PyWavelets.

Source code in spectrans/transforms/wavelet.py
def get_wavelet_filters(wavelet_name: str) -> tuple[Tensor, Tensor, Tensor, Tensor]:
    """Get filter coefficients from PyWavelets.

    Parameters
    ----------
    wavelet_name : str
        Name of the wavelet (e.g., 'db1', 'db2', 'db4', 'sym2').

    Returns
    -------
    tuple[Tensor, Tensor, Tensor, Tensor]
        Tuple of (dec_lo, dec_hi, rec_lo, rec_hi) filter tensors.

    Raises
    ------
    ValueError
        If wavelet is not supported by PyWavelets.
    """
    try:
        wavelet = pywt.Wavelet(wavelet_name)
    except ValueError as e:
        msg = f"Unsupported wavelet: {wavelet_name}"
        raise ValueError(msg) from e

    # Extract filters exactly as PyWavelets provides them
    # Use float64 for maximum precision compatibility with PyWavelets
    dec_lo = torch.tensor(wavelet.dec_lo, dtype=torch.float64)
    dec_hi = torch.tensor(wavelet.dec_hi, dtype=torch.float64)
    rec_lo = torch.tensor(wavelet.rec_lo, dtype=torch.float64)
    rec_hi = torch.tensor(wavelet.rec_hi, dtype=torch.float64)

    return dec_lo, dec_hi, rec_lo, rec_hi