
Spectral Transforms

spectrans.transforms

Spectral transform implementations for neural networks.

This module provides implementations of spectral transforms used in spectral transformer architectures. All transforms implement consistent interfaces through the base classes, enabling easy substitution and experimentation with different spectral methods. The transforms support both real and complex inputs, batch processing, and multi-dimensional operations where applicable.

Modules:

Name Description
base

Base classes and interfaces for spectral transforms.

cosine

Discrete Cosine and Sine Transform implementations.

fourier

Fast Fourier Transform implementations.

hadamard

Hadamard and related orthogonal transforms.

wavelet

Discrete Wavelet Transform implementations.

Classes:

Name Description
AdaptiveTransform

Transform with learnable parameters for adaptation.

DCT

Discrete Cosine Transform implementation.

DCT2D

2D Discrete Cosine Transform for image-like data.

DST

Discrete Sine Transform implementation.

DWT1D

1D Discrete Wavelet Transform.

DWT2D

2D Discrete Wavelet Transform.

FFT1D

1D Fast Fourier Transform with real/complex support.

FFT2D

2D Fast Fourier Transform for AFNO-style operations.

HadamardTransform

Fast Hadamard Transform implementation.

HadamardTransform2D

2D Hadamard Transform implementation.

MDCT

Modified Discrete Cosine Transform for audio processing.

MultiResolutionTransform

Base class for multi-resolution decompositions.

NeuralSpectralTransform

Transform with neural network components.

OrthogonalTransform

Base class for orthogonal transforms (DCT, DST, Hadamard).

RFFT

Real-input Fast Fourier Transform.

RFFT2D

2D Real-input Fast Fourier Transform.

SequencyHadamardTransform

Sequency-ordered Hadamard transform.

SlantTransform

Slant transform implementation.

SpectralPooling

Spectral pooling operation in frequency domain.

SpectralTransform

Base class for simple 1D spectral transforms.

UnitaryTransform

Base class for unitary transforms (FFT).

Examples:

Using Fourier transforms:

>>> import torch
>>> from spectrans.transforms import FFT1D, RFFT
>>> # Complex-input FFT
>>> complex_input = torch.randn(4, 64, dtype=torch.complex64)
>>> fft = FFT1D()
>>> complex_output = fft.transform(complex_input)
>>> reconstructed = fft.inverse_transform(complex_output)
>>>
>>> # Real-input FFT (keeps only the non-negative frequencies)
>>> real_input = torch.randn(4, 64)
>>> rfft = RFFT()
>>> freq_domain = rfft.transform(real_input)

Using orthogonal transforms:

>>> import torch
>>> from spectrans.transforms import DCT, HadamardTransform
>>> signal = torch.randn(4, 64)
>>> # Discrete Cosine Transform
>>> dct = DCT(normalized=True)
>>> dct_coeffs = dct.transform(signal)
>>>
>>> # Fast Hadamard Transform
>>> hadamard = HadamardTransform()
>>> hadamard_coeffs = hadamard.transform(signal, dim=-1)

Using wavelet transforms:

>>> import torch
>>> from spectrans.transforms import DWT1D
>>> signal = torch.randn(4, 64)
>>> dwt = DWT1D(wavelet='db4', levels=3)
>>> approx_coeffs, detail_coeffs = dwt.decompose(signal)
>>> reconstructed = dwt.reconstruct((approx_coeffs, detail_coeffs))
Notes

Mathematical Properties:

The transforms maintain important mathematical properties:

  1. Orthogonal Transforms (DCT, DST, Hadamard), when orthonormally scaled:
     • Preserve inner products: \(\langle \mathbf{x}, \mathbf{y} \rangle = \langle \mathcal{T}(\mathbf{x}), \mathcal{T}(\mathbf{y}) \rangle\)
     • Perfect reconstruction: \(\mathcal{T}^{-1}(\mathcal{T}(\mathbf{x})) = \mathbf{x}\)
     • Energy conservation (Parseval's theorem)

  2. Unitary Transforms (FFT), with norm="ortho":
     • Complex inner product preservation
     • Norm conservation: \(\|\mathcal{T}(\mathbf{x})\|_2 = \|\mathbf{x}\|_2\)
     • Hermitian symmetry for real inputs: \(X_k = \overline{X_{(n-k) \bmod n}}\)

  3. Multi-Resolution Transforms (DWT):
     • Perfect reconstruction from coefficients
     • Localization in both time and frequency
     • Compact support for finite-length wavelets
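The orthogonal and unitary properties above can be checked numerically. The following is a minimal NumPy illustration, not spectrans code: it builds an orthonormal Hadamard matrix with the Sylvester construction (the helper name `sylvester_hadamard` is hypothetical) and uses `numpy.fft.fft` with `norm="ortho"`, assuming the spectrans FFT classes follow the same orthonormal convention under their default `norm="ortho"`.

```python
import numpy as np


def sylvester_hadamard(n: int) -> np.ndarray:
    """Orthonormal Hadamard matrix of size n (n must be a power of two)."""
    if n < 1 or n & (n - 1):
        raise ValueError("n must be a power of two")
    h = np.array([[1.0]])
    while h.shape[0] < n:
        # Sylvester construction: H_2m = [[H_m, H_m], [H_m, -H_m]]
        h = np.block([[h, h], [h, -h]])
    return h / np.sqrt(n)


rng = np.random.default_rng(0)
n = 16
x, y = rng.standard_normal(n), rng.standard_normal(n)

# Orthogonal: the inverse is the transpose and inner products are preserved.
H = sylvester_hadamard(n)
assert np.allclose(H @ H.T, np.eye(n))
assert np.isclose(x @ y, (H @ x) @ (H @ y))

# Unitary: with norm="ortho" the DFT preserves the 2-norm (Parseval).
X = np.fft.fft(x, norm="ortho")
assert np.isclose(np.linalg.norm(X), np.linalg.norm(x))

# Hermitian symmetry for real input: X[k] == conj(X[(n - k) % n]).
k = np.arange(n)
assert np.allclose(X, np.conj(X[(n - k) % n]))
```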

Implementation Details:

  • All transforms support batch processing with proper broadcasting
  • Complex number operations use the spectrans.utils.complex module
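  • Numerical stability is aided by orthonormal scaling (the defaults norm="ortho" for the FFT classes and normalized=True for DCT/DST), which keeps coefficient magnitudes on the same scale as the input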
  • GPU acceleration through PyTorch's native FFT operations
  • In-place operations used where possible

Performance Characteristics:

  • FFT: \(O(n \log n)\) time complexity
  • DCT/DST: \(O(n^2)\) per transformed vector with the dense-matrix implementation shown below; FFT-based algorithms reduce this to \(O(n \log n)\)
  • Hadamard: \(O(n \log n)\) fast transform algorithms
  • DWT: \(O(n)\) time complexity with compact support wavelets
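To make the \(O(n \log n)\) Hadamard bound concrete, here is a minimal NumPy sketch of the iterative fast Walsh-Hadamard transform: \(\log_2 n\) butterfly stages, each doing \(n/2\) add/subtract pairs. It is not necessarily how `HadamardTransform` is implemented; it assumes natural (Sylvester) ordering and orthonormal \(1/\sqrt{n}\) scaling, and the name `fwht` is hypothetical.

```python
import numpy as np


def fwht(x: np.ndarray) -> np.ndarray:
    """Orthonormal fast Walsh-Hadamard transform along the last axis (natural order)."""
    x = np.array(x, dtype=float)  # copy so the caller's array is untouched
    n = x.shape[-1]
    if n < 1 or n & (n - 1):
        raise ValueError("length must be a power of two")
    h = 1
    while h < n:
        # View as blocks of size 2h and combine the two halves of each block.
        y = x.reshape(*x.shape[:-1], n // (2 * h), 2, h)
        a, b = y[..., 0, :], y[..., 1, :]
        x = np.stack((a + b, a - b), axis=-2).reshape(*x.shape[:-1], n)
        h *= 2
    return x / np.sqrt(n)


# Compare against the explicit (Sylvester-ordered) Hadamard matrix.
rng = np.random.default_rng(0)
x = rng.standard_normal((3, 8))
H = np.array([[1.0]])
while H.shape[0] < 8:
    H = np.block([[H, H], [H, -H]])
H /= np.sqrt(8)
assert np.allclose(fwht(x), x @ H.T)
assert np.allclose(fwht(fwht(x)), x)  # symmetric and orthonormal, so self-inverse
```

The in-place loop of the textbook algorithm is replaced here by a reshape per stage, which keeps the batch dimensions intact and avoids Python-level loops over elements.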
See Also

  • spectrans.transforms.base : Base classes and interfaces.
  • spectrans.utils.complex : Complex tensor operations.
  • spectrans.core.registry : Component registration for transforms.

Classes

AdaptiveTransform

AdaptiveTransform(input_dim: int, learnable: bool = True)

Bases: NeuralSpectralTransform

Base class for adaptive transforms with learnable parameters.

Adaptive transforms can learn their basis functions or transformation parameters from data. This is useful for applications where the optimal spectral representation depends on the specific data distribution.

Parameters:

Name Type Description Default
input_dim int

Input dimension size.

required
learnable bool

Whether transform parameters are learnable.

True
Source code in spectrans/transforms/base.py
def __init__(self, input_dim: int, learnable: bool = True):
    super().__init__()
    self.input_dim = input_dim
    self.learnable = learnable

MultiResolutionTransform

MultiResolutionTransform(levels: int = 1)

Bases: Transform

Base class for multi-resolution transforms.

For transforms that decompose signals into multiple components at different resolution levels, such as Discrete Wavelet Transform (DWT).

These transforms differ mathematically from simple spectral transforms because they return multiple components:

  • Approximation coefficients at the coarsest level
  • Detail coefficients at each level

This matches the mathematical formulation \(\mathrm{DWT}(\mathbf{x}) = \{c_{A_J}, \{c_{D_j}\}_{j=1}^{J}\}\), where \(J\) is the number of decomposition levels.
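A minimal sketch of this decomposition, using the two-tap orthonormal Haar filters in place of a longer wavelet such as the db4 used in the examples. The helper names are hypothetical and this is not the `DWT1D` implementation; it only mirrors the return convention documented for `decompose` below (approximation first, details ordered coarsest to finest).

```python
import numpy as np

SQRT2 = np.sqrt(2.0)


def haar_decompose(x, levels):
    """Multi-level orthonormal Haar DWT along the last axis -> (approx, details)."""
    approx = np.asarray(x, dtype=float)
    details = []
    for _ in range(levels):
        if approx.shape[-1] % 2:
            raise ValueError("length must be divisible by 2**levels")
        even, odd = approx[..., 0::2], approx[..., 1::2]
        details.append((even - odd) / SQRT2)  # high-pass: c_{D_j}
        approx = (even + odd) / SQRT2  # low-pass: c_{A_j}
    return approx, details[::-1]  # reverse so details run coarsest -> finest


def haar_reconstruct(coeffs):
    """Invert haar_decompose, consuming details from coarsest to finest."""
    approx, details = coeffs
    for detail in details:
        even = (approx + detail) / SQRT2
        odd = (approx - detail) / SQRT2
        out = np.empty(even.shape[:-1] + (2 * even.shape[-1],))
        out[..., 0::2], out[..., 1::2] = even, odd
        approx = out
    return approx


x = np.random.default_rng(0).standard_normal((2, 16))
approx, details = haar_decompose(x, levels=3)
print(approx.shape, [d.shape for d in details])  # (2, 2) [(2, 2), (2, 4), (2, 8)]
assert np.allclose(haar_reconstruct((approx, details)), x)
# Orthonormal filters also conserve energy across all coefficients.
energy = np.sum(approx**2) + sum(np.sum(d**2) for d in details)
assert np.isclose(energy, np.sum(x**2))
```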

Parameters:

Name Type Description Default
levels int

Number of decomposition levels.

1

Methods:

Name Description
decompose

Decompose signal into multiple resolution levels.

reconstruct

Reconstruct signal from multi-resolution coefficients.

Source code in spectrans/transforms/base.py
def __init__(self, levels: int = 1):
    super().__init__()
    self.levels = levels
Functions
decompose abstractmethod
decompose(x: Tensor, levels: int | None = None, dim: int = -1) -> tuple[Tensor, list[Tensor]]

Decompose signal into multiple resolution levels.

Parameters:

Name Type Description Default
x Tensor

Input tensor to decompose.

required
levels int | None

Number of levels. If None, use self.levels.

None
dim int

Dimension along which to apply decomposition.

-1

Returns:

Type Description
tuple[Tensor, list[Tensor]]

Tuple of (approximation_coefficients, detail_coefficients_list) where detail_coefficients_list contains coefficients from coarsest to finest level.

Source code in spectrans/transforms/base.py
@abstractmethod
def decompose(
    self, x: Tensor, levels: int | None = None, dim: int = -1
) -> tuple[Tensor, list[Tensor]]:
    """Decompose signal into multiple resolution levels.

    Parameters
    ----------
    x : Tensor
        Input tensor to decompose.
    levels : int | None, default=None
        Number of levels. If None, use self.levels.
    dim : int, default=-1
        Dimension along which to apply decomposition.

    Returns
    -------
    tuple[Tensor, list[Tensor]]
        Tuple of (approximation_coefficients, detail_coefficients_list)
        where detail_coefficients_list contains coefficients from
        coarsest to finest level.
    """
    pass
reconstruct abstractmethod
reconstruct(coeffs: tuple[Tensor, list[Tensor]], dim: int = -1) -> Tensor

Reconstruct signal from multi-resolution coefficients.

Parameters:

Name Type Description Default
coeffs tuple[Tensor, list[Tensor]]

Tuple of (approximation_coefficients, detail_coefficients_list).

required
dim int

Dimension along which to apply reconstruction.

-1

Returns:

Type Description
Tensor

Reconstructed tensor.

Source code in spectrans/transforms/base.py
@abstractmethod
def reconstruct(self, coeffs: tuple[Tensor, list[Tensor]], dim: int = -1) -> Tensor:
    """Reconstruct signal from multi-resolution coefficients.

    Parameters
    ----------
    coeffs : tuple[Tensor, list[Tensor]]
        Tuple of (approximation_coefficients, detail_coefficients_list).
    dim : int, default=-1
        Dimension along which to apply reconstruction.

    Returns
    -------
    Tensor
        Reconstructed tensor.
    """
    pass

NeuralSpectralTransform

Bases: SpectralTransform

Base class for learnable spectral transforms.

This class is for transforms that can learn their parameters during training, such as learnable filters in the frequency domain.

Methods:

Name Description
forward

Forward pass through the neural spectral transform.

Functions
forward
forward(x: Tensor) -> Tensor

Forward pass through the neural spectral transform.

By default, applies the transform operation. Subclasses can override this for more complex learned behaviors.

Parameters:

Name Type Description Default
x Tensor

Input tensor.

required

Returns:

Type Description
Tensor

Output tensor.

Source code in spectrans/transforms/base.py
def forward(self, x: Tensor) -> Tensor:
    """Forward pass through the neural spectral transform.

    By default, applies the transform operation. Subclasses can
    override this for more complex learned behaviors.

    Parameters
    ----------
    x : Tensor
        Input tensor.

    Returns
    -------
    Tensor
        Output tensor.
    """
    return self.transform(x)

OrthogonalTransform

Bases: SpectralTransform

Base class for orthogonal transforms.

Orthogonal transforms preserve inner products and have the property that their inverse is their transpose. This includes DCT, DST, and Hadamard transforms.

Attributes:

Name Type Description
is_orthogonal bool

Orthogonal transforms preserve inner products.

Attributes
is_orthogonal property
is_orthogonal: bool

Orthogonal transforms preserve inner products.

SpectralTransform

Bases: Transform

Base class for simple spectral transforms.

For transforms that map Tensor → Tensor along a specified dimension, such as FFT, DCT, DST, and Hadamard transforms. These transforms operate on a single dimension and return tensors of the same shape, with the exception of RFFT, which keeps only the n // 2 + 1 non-negative frequency bins along the transformed dimension. Fourier outputs are complex-valued.

Mathematical operations supported:

  • Fourier transforms (FFT, RFFT)
  • Discrete Cosine Transform (DCT)
  • Discrete Sine Transform (DST)
  • Hadamard transform

Methods:

Name Description
transform

Apply forward transform along specified dimension.

inverse_transform

Apply inverse transform along specified dimension.

Attributes:

Name Type Description
is_orthogonal bool

Whether the transform is orthogonal.

is_unitary bool

Whether the transform is unitary.

Attributes
is_orthogonal property
is_orthogonal: bool

Whether the transform is orthogonal.

Returns:

Type Description
bool

True if the transform preserves inner products.

is_unitary property
is_unitary: bool

Whether the transform is unitary.

Returns:

Type Description
bool

True if the transform preserves complex inner products.

Functions
transform abstractmethod
transform(x: Tensor, dim: int = -1) -> Tensor

Apply forward transform along specified dimension.

Parameters:

Name Type Description Default
x Tensor

Input tensor to transform.

required
dim int

Dimension along which to apply the transform.

-1

Returns:

Type Description
Tensor

Transformed tensor with same shape as input.

Source code in spectrans/transforms/base.py
@abstractmethod
def transform(self, x: Tensor, dim: int = -1) -> Tensor:
    """Apply forward transform along specified dimension.

    Parameters
    ----------
    x : Tensor
        Input tensor to transform.
    dim : int, default=-1
        Dimension along which to apply the transform.

    Returns
    -------
    Tensor
        Transformed tensor with same shape as input.
    """
    pass
inverse_transform abstractmethod
inverse_transform(x: Tensor, dim: int = -1) -> Tensor

Apply inverse transform along specified dimension.

Parameters:

Name Type Description Default
x Tensor

Transformed tensor to invert.

required
dim int

Dimension along which to apply the inverse transform.

-1

Returns:

Type Description
Tensor

Inverse transformed tensor with same shape as input.

Source code in spectrans/transforms/base.py
@abstractmethod
def inverse_transform(self, x: Tensor, dim: int = -1) -> Tensor:
    """Apply inverse transform along specified dimension.

    Parameters
    ----------
    x : Tensor
        Transformed tensor to invert.
    dim : int, default=-1
        Dimension along which to apply the inverse transform.

    Returns
    -------
    Tensor
        Inverse transformed tensor with same shape as input.
    """
    pass

UnitaryTransform

Bases: SpectralTransform

Base class for unitary transforms.

Unitary transforms preserve complex inner products and have the property that their inverse is their conjugate transpose. This includes the Discrete Fourier Transform (DFT/FFT) with orthonormal scaling (norm="ortho", the default for the FFT classes in this module).

Attributes:

Name Type Description
is_unitary bool

Unitary transforms preserve complex inner products.

Attributes
is_unitary property
is_unitary: bool

Unitary transforms preserve complex inner products.

DCT

DCT(normalized: bool = True)

Bases: OrthogonalTransform

Discrete Cosine Transform (Type-II).

The DCT-II is the most commonly used DCT variant, often referred to as simply "the DCT". It's widely used in signal compression.

Parameters:

Name Type Description Default
normalized bool

Whether to use orthonormal normalization.

True

Methods:

Name Description
transform

Apply DCT-II transform.

inverse_transform

Apply inverse DCT (DCT-III).

Source code in spectrans/transforms/cosine.py
def __init__(self, normalized: bool = True):
    super().__init__()
    self.normalized = normalized
Functions
transform
transform(x: Tensor, dim: int = -1) -> Tensor

Apply DCT-II transform.

Parameters:

Name Type Description Default
x Tensor

Input tensor.

required
dim int

Dimension along which to apply DCT.

-1

Returns:

Type Description
Tensor

DCT coefficients.

Source code in spectrans/transforms/cosine.py
def transform(self, x: Tensor, dim: int = -1) -> Tensor:
    """Apply DCT-II transform.

    Parameters
    ----------
    x : Tensor
        Input tensor.
    dim : int, default=-1
        Dimension along which to apply DCT.

    Returns
    -------
    Tensor
        DCT coefficients.
    """
    n = x.shape[dim]

    # Create DCT matrix
    dct_matrix = self._create_dct_matrix(n, x.device, x.dtype)

    # Apply DCT via matrix multiplication
    if dim == -1 or dim == x.ndim - 1:
        result = torch.matmul(x, dct_matrix.T)
    else:
        # Move dimension to last position
        x_moved = x.transpose(dim, -1)
        result = torch.matmul(x_moved, dct_matrix.T)
        result = result.transpose(dim, -1)

    return result
inverse_transform
inverse_transform(x: Tensor, dim: int = -1) -> Tensor

Apply inverse DCT (DCT-III).

Parameters:

Name Type Description Default
x Tensor

DCT coefficients.

required
dim int

Dimension along which to apply inverse DCT.

-1

Returns:

Type Description
Tensor

Reconstructed signal.

Source code in spectrans/transforms/cosine.py
def inverse_transform(self, x: Tensor, dim: int = -1) -> Tensor:
    """Apply inverse DCT (DCT-III).

    Parameters
    ----------
    x : Tensor
        DCT coefficients.
    dim : int, default=-1
        Dimension along which to apply inverse DCT.

    Returns
    -------
    Tensor
        Reconstructed signal.
    """
    n = x.shape[dim]

    # Create inverse DCT matrix (DCT-III)
    idct_matrix = self._create_idct_matrix(n, x.device, x.dtype)

    # Apply inverse DCT via matrix multiplication
    if dim == -1 or dim == x.ndim - 1:
        result = torch.matmul(x, idct_matrix.T)
    else:
        # Move dimension to last position
        x_moved = x.transpose(dim, -1)
        result = torch.matmul(x_moved, idct_matrix.T)
        result = result.transpose(dim, -1)

    return result
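`_create_dct_matrix` and `_create_idct_matrix` are not shown in this reference. Assuming `normalized=True` yields the orthonormal DCT-II matrix \(C_{km} = s_k \sqrt{2/n} \cos(\pi (m + \tfrac{1}{2}) k / n)\) with \(s_0 = 1/\sqrt{2}\) and \(s_k = 1\) otherwise, the inverse (DCT-III) is simply \(C^\top\). The NumPy sketch below (hypothetical helper names) builds that matrix and applies it along an arbitrary axis by moving that axis last, multiplying, and moving it back, the same strategy as the source above.

```python
import numpy as np


def dct2_matrix(n: int) -> np.ndarray:
    """Orthonormal DCT-II matrix; row k holds the k-th cosine basis vector."""
    k = np.arange(n)[:, None]
    m = np.arange(n)[None, :]
    c = np.sqrt(2.0 / n) * np.cos(np.pi * (m + 0.5) * k / n)
    c[0, :] /= np.sqrt(2.0)  # s_0 = 1/sqrt(2) gives the DC row unit norm
    return c


def apply_along(x: np.ndarray, mat: np.ndarray, dim: int = -1) -> np.ndarray:
    """Multiply every 1D slice along `dim` by `mat` (move axis last, matmul, move back)."""
    x_last = np.moveaxis(x, dim, -1)
    return np.moveaxis(x_last @ mat.T, -1, dim)


rng = np.random.default_rng(0)
x = rng.standard_normal((4, 8, 5))
C = dct2_matrix(8)
assert np.allclose(C @ C.T, np.eye(8))  # orthogonal: DCT-III inverse is C.T

coeffs = apply_along(x, C, dim=1)  # DCT-II along a middle axis
assert np.allclose(apply_along(coeffs, C.T, dim=1), x)
assert np.isclose(np.sum(coeffs**2), np.sum(x**2))  # Parseval
```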

DCT2D

DCT2D(normalized: bool = True)

Bases: SpectralTransform2D

2D Discrete Cosine Transform.

Applies DCT-II along both spatial dimensions, commonly used in image compression (e.g., JPEG).

Parameters:

Name Type Description Default
normalized bool

Whether to use orthonormal normalization.

True

Methods:

Name Description
transform

Apply 2D DCT.

inverse_transform

Apply inverse 2D DCT.

Source code in spectrans/transforms/cosine.py
def __init__(self, normalized: bool = True):
    super().__init__()
    self.normalized = normalized
    self.dct = DCT(normalized=normalized)
Functions
transform
transform(x: Tensor, dim: tuple[int, int] = (-2, -1)) -> Tensor

Apply 2D DCT.

Parameters:

Name Type Description Default
x Tensor

Input tensor.

required
dim tuple[int, int]

Dimensions along which to apply 2D DCT.

(-2, -1)

Returns:

Type Description
Tensor

2D DCT coefficients.

Source code in spectrans/transforms/cosine.py
def transform(self, x: Tensor, dim: tuple[int, int] = (-2, -1)) -> Tensor:
    """Apply 2D DCT.

    Parameters
    ----------
    x : Tensor
        Input tensor.
    dim : tuple[int, int], default=(-2, -1)
        Dimensions along which to apply 2D DCT.

    Returns
    -------
    Tensor
        2D DCT coefficients.
    """
    # Apply DCT along first dimension
    result = self.dct.transform(x, dim=dim[0])
    # Apply DCT along second dimension
    result = self.dct.transform(result, dim=dim[1])
    return result
inverse_transform
inverse_transform(x: Tensor, dim: tuple[int, int] = (-2, -1)) -> Tensor

Apply inverse 2D DCT.

Parameters:

Name Type Description Default
x Tensor

2D DCT coefficients.

required
dim tuple[int, int]

Dimensions along which to apply inverse 2D DCT.

(-2, -1)

Returns:

Type Description
Tensor

Reconstructed signal.

Source code in spectrans/transforms/cosine.py
def inverse_transform(self, x: Tensor, dim: tuple[int, int] = (-2, -1)) -> Tensor:
    """Apply inverse 2D DCT.

    Parameters
    ----------
    x : Tensor
        2D DCT coefficients.
    dim : tuple[int, int], default=(-2, -1)
        Dimensions along which to apply inverse 2D DCT.

    Returns
    -------
    Tensor
        Reconstructed signal.
    """
    # Apply inverse DCT along second dimension
    result = self.dct.inverse_transform(x, dim=dim[1])
    # Apply inverse DCT along first dimension
    result = self.dct.inverse_transform(result, dim=dim[0])
    return result
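Because the 2D DCT is separable, applying the 1D DCT along each axis is equivalent to \(C_H X C_W^\top\), so the order of the two passes (and of the two inverse passes) does not change the result beyond floating-point rounding. A minimal batched NumPy sketch, assuming an orthonormal DCT-II matrix; the helpers are hypothetical, not `DCT2D` internals.

```python
import numpy as np


def dct2_matrix(n: int) -> np.ndarray:
    """Orthonormal DCT-II matrix (hypothetical helper)."""
    k = np.arange(n)[:, None]
    m = np.arange(n)[None, :]
    c = np.sqrt(2.0 / n) * np.cos(np.pi * (m + 0.5) * k / n)
    c[0, :] /= np.sqrt(2.0)
    return c


def dct_2d(x: np.ndarray) -> np.ndarray:
    """Separable 2D DCT-II over the last two axes: C_H @ X @ C_W.T (batched)."""
    ch, cw = dct2_matrix(x.shape[-2]), dct2_matrix(x.shape[-1])
    return ch @ x @ cw.T


def idct_2d(coeffs: np.ndarray) -> np.ndarray:
    """Inverse of dct_2d; orthogonality gives C_H.T @ Y @ C_W."""
    ch, cw = dct2_matrix(coeffs.shape[-2]), dct2_matrix(coeffs.shape[-1])
    return ch.T @ coeffs @ cw


x = np.random.default_rng(0).standard_normal((2, 6, 8))  # (batch, H, W)
coeffs = dct_2d(x)
assert np.allclose(idct_2d(coeffs), x)
assert np.isclose(np.sum(coeffs**2), np.sum(x**2))  # energy preserved
```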

DST

DST(normalized: bool = True)

Bases: OrthogonalTransform

Discrete Sine Transform (Type-II).

The DST-II is analogous to the DCT-II but uses sine functions.

Parameters:

Name Type Description Default
normalized bool

Whether to use orthonormal normalization.

True

Methods:

Name Description
transform

Apply DST-II transform.

inverse_transform

Apply inverse DST (DST-III).

Source code in spectrans/transforms/cosine.py
def __init__(self, normalized: bool = True):
    super().__init__()
    self.normalized = normalized
Functions
transform
transform(x: Tensor, dim: int = -1) -> Tensor

Apply DST-II transform.

Parameters:

Name Type Description Default
x Tensor

Input tensor.

required
dim int

Dimension along which to apply DST.

-1

Returns:

Type Description
Tensor

DST coefficients.

Source code in spectrans/transforms/cosine.py
def transform(self, x: Tensor, dim: int = -1) -> Tensor:
    """Apply DST-II transform.

    Parameters
    ----------
    x : Tensor
        Input tensor.
    dim : int, default=-1
        Dimension along which to apply DST.

    Returns
    -------
    Tensor
        DST coefficients.
    """
    n = x.shape[dim]

    # Create DST matrix
    dst_matrix = self._create_dst_matrix(n, x.device, x.dtype)

    # Apply DST via matrix multiplication
    if dim == -1 or dim == x.ndim - 1:
        result = torch.matmul(x, dst_matrix.T)
    else:
        # Move dimension to last position
        x_moved = x.transpose(dim, -1)
        result = torch.matmul(x_moved, dst_matrix.T)
        result = result.transpose(dim, -1)

    return result
inverse_transform
inverse_transform(x: Tensor, dim: int = -1) -> Tensor

Apply inverse DST (DST-III).

Parameters:

Name Type Description Default
x Tensor

DST coefficients.

required
dim int

Dimension along which to apply inverse DST.

-1

Returns:

Type Description
Tensor

Reconstructed signal.

Source code in spectrans/transforms/cosine.py
def inverse_transform(self, x: Tensor, dim: int = -1) -> Tensor:
    """Apply inverse DST (DST-III).

    Parameters
    ----------
    x : Tensor
        DST coefficients.
    dim : int, default=-1
        Dimension along which to apply inverse DST.

    Returns
    -------
    Tensor
        Reconstructed signal.
    """
    n = x.shape[dim]

    # Create inverse DST matrix (DST-III)
    idst_matrix = self._create_idst_matrix(n, x.device, x.dtype)

    # Apply inverse DST via matrix multiplication
    if dim == -1 or dim == x.ndim - 1:
        result = torch.matmul(x, idst_matrix.T)
    else:
        # Move dimension to last position
        x_moved = x.transpose(dim, -1)
        result = torch.matmul(x_moved, idst_matrix.T)
        result = result.transpose(dim, -1)

    return result
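As with the DCT, `_create_dst_matrix` is not shown. Assuming the orthonormal DST-II convention (the one SciPy uses for `norm="ortho"`), \(S_{km} = s_k \sqrt{2/n} \sin(\pi (m + \tfrac{1}{2})(k + 1)/n)\) with \(s_{n-1} = 1/\sqrt{2}\) and \(s_k = 1\) otherwise, the matrix is orthogonal and the DST-III inverse is \(S^\top\). A minimal NumPy sketch with a hypothetical helper:

```python
import numpy as np


def dst2_matrix(n: int) -> np.ndarray:
    """Orthonormal DST-II matrix; row k holds sin(pi*(m + 1/2)*(k + 1)/n), scaled."""
    k = np.arange(n)[:, None]
    m = np.arange(n)[None, :]
    s = np.sqrt(2.0 / n) * np.sin(np.pi * (m + 0.5) * (k + 1) / n)
    s[-1, :] /= np.sqrt(2.0)  # s_{n-1} = 1/sqrt(2) gives the last row unit norm
    return s


S = dst2_matrix(8)
assert np.allclose(S @ S.T, np.eye(8))  # orthogonal: DST-III inverse is S.T

x = np.random.default_rng(0).standard_normal((3, 8))
coeffs = x @ S.T  # DST-II along the last axis
assert np.allclose(coeffs @ S, x)  # DST-III reconstructs the input
```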

MDCT

MDCT(block_size: int, window: str = 'sine')

Bases: OrthogonalTransform

Modified Discrete Cosine Transform.

The MDCT is a lapped transform based on DCT-IV with 50% overlap, commonly used in audio compression (MP3, AAC).

Parameters:

Name Type Description Default
block_size int

Size of the transform block (must be even).

required
window str

Window function to use: "sine" or "vorbis".

"sine"

Methods:

Name Description
transform

Apply MDCT.

inverse_transform

Apply inverse MDCT (currently raises NotImplementedError).

Source code in spectrans/transforms/cosine.py
def __init__(self, block_size: int, window: str = "sine"):
    super().__init__()
    if block_size % 2 != 0:
        raise ValueError("Block size must be even for MDCT")

    self.block_size = block_size
    self.half_block = block_size // 2
    self.window_type = window
Functions
transform
transform(x: Tensor, dim: int = -1) -> Tensor

Apply MDCT.

Parameters:

Name Type Description Default
x Tensor

Input tensor. Length along dim must be a multiple of half_block (block_size // 2) and at least block_size to produce any coefficients.

required
dim int

Dimension along which to apply MDCT.

-1

Returns:

Type Description
Tensor

MDCT coefficients.

Source code in spectrans/transforms/cosine.py
def transform(self, x: Tensor, dim: int = -1) -> Tensor:
    """Apply MDCT.

    Parameters
    ----------
    x : Tensor
        Input tensor. Length along dim must be multiple of half_block.
    dim : int, default=-1
        Dimension along which to apply MDCT.

    Returns
    -------
    Tensor
        MDCT coefficients.
    """
    n = x.shape[dim]
    if n % self.half_block != 0:
        raise ValueError(f"Input length {n} must be multiple of {self.half_block}")

    # Number of blocks
    num_blocks = (n - self.half_block) // self.half_block

    # Get window
    window = self._get_window(self.block_size, x.device, x.dtype)

    # Prepare output
    output_shape = list(x.shape)
    output_shape[dim] = num_blocks * self.half_block
    output = torch.zeros(output_shape, device=x.device, dtype=x.dtype)

    # Process overlapping blocks
    for i in range(num_blocks):
        start = i * self.half_block
        end = start + self.block_size

        # Extract and window block
        if dim == -1:
            block = x[..., start:end] * window
        else:
            indices = torch.arange(start, end, device=x.device)
            block = torch.index_select(x, dim, indices)
            # Normalize dim so negative values other than -1 still broadcast the window correctly
            block = block * window.reshape([-1] + [1] * (x.ndim - (dim % x.ndim) - 1))

        # Apply DCT-IV (simplified using DCT-II)
        block_dct = self._dct4(block, dim=-1 if dim == -1 else dim)

        # Store result
        out_start = i * self.half_block
        out_end = out_start + self.half_block

        if dim == -1:
            output[..., out_start:out_end] = block_dct[..., : self.half_block]
        else:
            # Handle arbitrary dimension
            indices = torch.arange(out_start, out_end, device=x.device)
            output.index_copy_(
                dim,
                indices,
                torch.index_select(
                    block_dct, dim, torch.arange(self.half_block, device=x.device)
                ),
            )

    return output
inverse_transform
inverse_transform(x: Tensor, dim: int = -1) -> Tensor

Apply inverse MDCT. Not implemented yet: this method currently raises NotImplementedError.

Parameters:

Name Type Description Default
x Tensor

MDCT coefficients.

required
dim int

Dimension along which to apply inverse MDCT.

-1

Returns:

Type Description
Tensor

Reconstructed signal with overlap-add.

Source code in spectrans/transforms/cosine.py
def inverse_transform(self, x: Tensor, dim: int = -1) -> Tensor:
    """Apply inverse MDCT.

    Parameters
    ----------
    x : Tensor
        MDCT coefficients.
    dim : int, default=-1
        Dimension along which to apply inverse MDCT.

    Returns
    -------
    Tensor
        Reconstructed signal with overlap-add.
    """
    # Inverse MDCT implementation would require overlap-add reconstruction
    # This is complex and beyond the scope of this basic implementation
    raise NotImplementedError("Inverse MDCT requires overlap-add reconstruction")
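Since the documented inverse raises NotImplementedError, here is a minimal standalone NumPy sketch of the textbook MDCT with a sine window and its overlap-add inverse (time-domain aliasing cancellation) for reference. It is not the spectrans implementation: it evaluates the MDCT directly from its cosine basis \(\cos[\tfrac{\pi}{N}(n + \tfrac{1}{2} + \tfrac{N}{2})(k + \tfrac{1}{2})]\) in \(O(N^2)\) per block rather than through a DCT-IV, uses `n_half` for \(N\) (`half_block` above), and all function names are hypothetical. The \(2/N\) synthesis scale assumes a symmetric window satisfying the Princen-Bradley condition \(w_n^2 + w_{n+N}^2 = 1\).

```python
import numpy as np


def mdct_basis(n_half: int) -> np.ndarray:
    """(N, 2N) basis cos[pi/N * (n + 1/2 + N/2) * (k + 1/2)] with N = n_half."""
    n = np.arange(2 * n_half)[None, :]
    k = np.arange(n_half)[:, None]
    return np.cos(np.pi / n_half * (n + 0.5 + n_half / 2) * (k + 0.5))


def sine_window(n_half: int) -> np.ndarray:
    """Sine window of length 2N; satisfies w[n]**2 + w[n + N]**2 == 1."""
    return np.sin(np.pi * (np.arange(2 * n_half) + 0.5) / (2 * n_half))


def mdct(x: np.ndarray, n_half: int) -> np.ndarray:
    """1D MDCT: frames of 2N samples, hop N, N coefficients per frame."""
    basis, w = mdct_basis(n_half), sine_window(n_half)
    num_blocks = (len(x) - n_half) // n_half  # same count as MDCT.transform
    frames = np.stack([x[i * n_half : i * n_half + 2 * n_half] for i in range(num_blocks)])
    return (frames * w) @ basis.T  # shape (num_blocks, N)


def imdct(coeffs: np.ndarray, n_half: int) -> np.ndarray:
    """Inverse MDCT: window each 2N-sample frame, then overlap-add with hop N."""
    basis, w = mdct_basis(n_half), sine_window(n_half)
    frames = (2.0 / n_half) * (coeffs @ basis) * w  # shape (num_blocks, 2N)
    out = np.zeros((len(coeffs) + 1) * n_half)
    for i, frame in enumerate(frames):
        out[i * n_half : i * n_half + 2 * n_half] += frame
    return out


N = 8
x = np.random.default_rng(0).standard_normal(64)
padded = np.pad(x, (N, N))  # edge samples need a second, overlapping block
X = mdct(padded, N)
print(X.shape)  # (9, 8)
y = imdct(X, N)
assert np.allclose(y[N:-N], x)  # aliasing cancels in the overlaps (TDAC)
```

The zero padding matters: the first and last N samples are covered by only one block, so their aliasing is never cancelled and only the interior is reconstructed exactly.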

FFT1D

FFT1D(norm: FFTNorm = 'ortho')

Bases: UnitaryTransform

1D Fast Fourier Transform.

Applies 1D FFT along a specified dimension of the input tensor.

Parameters:

Name Type Description Default
norm FFTNorm

Normalization mode: "forward", "backward", or "ortho".

"ortho"

Methods:

Name Description
transform

Apply 1D FFT.

inverse_transform

Apply inverse 1D FFT.

Source code in spectrans/transforms/fourier.py
def __init__(self, norm: FFTNorm = "ortho"):
    super().__init__()
    self.norm = norm
Functions
transform
transform(x: Tensor, dim: int = -1) -> ComplexTensor

Apply 1D FFT.

Parameters:

Name Type Description Default
x Tensor

Input tensor of real or complex values.

required
dim int

Dimension along which to apply FFT.

-1

Returns:

Type Description
ComplexTensor

Complex-valued FFT result.

Source code in spectrans/transforms/fourier.py
def transform(self, x: Tensor, dim: int = -1) -> ComplexTensor:
    """Apply 1D FFT.

    Parameters
    ----------
    x : Tensor
        Input tensor of real or complex values.
    dim : int, default=-1
        Dimension along which to apply FFT.

    Returns
    -------
    ComplexTensor
        Complex-valued FFT result.
    """
    return safe_fft(x, dim=dim, norm=self.norm)
inverse_transform
inverse_transform(x: ComplexTensor, dim: int = -1) -> Tensor

Apply inverse 1D FFT.

Parameters:

Name Type Description Default
x ComplexTensor

Complex-valued FFT coefficients.

required
dim int

Dimension along which to apply inverse FFT.

-1

Returns:

Type Description
Tensor

Inverse FFT result (may be complex if input was complex).

Source code in spectrans/transforms/fourier.py
def inverse_transform(self, x: ComplexTensor, dim: int = -1) -> Tensor:
    """Apply inverse 1D FFT.

    Parameters
    ----------
    x : ComplexTensor
        Complex-valued FFT coefficients.
    dim : int, default=-1
        Dimension along which to apply inverse FFT.

    Returns
    -------
    Tensor
        Inverse FFT result (may be complex if input was complex).
    """
    return safe_ifft(x, dim=dim, norm=self.norm)
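The `norm` modes decide where the \(1/n\) factor goes. Assuming `safe_fft` and `safe_ifft` forward `norm` to `torch.fft.fft` and `torch.fft.ifft`, the semantics match `numpy.fft`, which this small sketch uses so it runs without PyTorch:

```python
import numpy as np

x = np.random.default_rng(0).standard_normal(16)
n = x.size
raw = np.fft.fft(x)  # unnormalized DFT sums

# Each mode puts the 1/n factor in a different place; every round trip is exact.
for norm, fwd_scale in [("backward", 1.0), ("ortho", 1 / np.sqrt(n)), ("forward", 1 / n)]:
    X = np.fft.fft(x, norm=norm)
    assert np.allclose(X, fwd_scale * raw)
    assert np.allclose(np.fft.ifft(X, norm=norm), x)

# Only "ortho" is unitary, so only it preserves the 2-norm.
assert np.isclose(np.linalg.norm(np.fft.fft(x, norm="ortho")), np.linalg.norm(x))
assert not np.isclose(np.linalg.norm(np.fft.fft(x, norm="backward")), np.linalg.norm(x))
```

This is why "ortho" is the natural default for `FFT1D`, which is documented as a `UnitaryTransform`.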

FFT2D

FFT2D(norm: FFTNorm = 'ortho')

Bases: SpectralTransform2D

2D Fast Fourier Transform.

Applies a 2D FFT over two dimensions of the input tensor, by default the last two (dim=(-2, -1)).

Parameters:

Name Type Description Default
norm FFTNorm

Normalization mode: "forward", "backward", or "ortho".

"ortho"

Methods:

Name Description
transform

Apply 2D FFT.

inverse_transform

Apply inverse 2D FFT.

Source code in spectrans/transforms/fourier.py
def __init__(self, norm: FFTNorm = "ortho"):
    super().__init__()
    self.norm = norm
Functions
transform
transform(x: Tensor, dim: tuple[int, int] = (-2, -1)) -> ComplexTensor

Apply 2D FFT.

Parameters:

Name Type Description Default
x Tensor

Input tensor of real or complex values.

required
dim tuple[int, int]

Dimensions along which to apply 2D FFT.

(-2, -1)

Returns:

Type Description
ComplexTensor

Complex-valued 2D FFT result.

Source code in spectrans/transforms/fourier.py
def transform(self, x: Tensor, dim: tuple[int, int] = (-2, -1)) -> ComplexTensor:
    """Apply 2D FFT.

    Parameters
    ----------
    x : Tensor
        Input tensor of real or complex values.
    dim : tuple[int, int], default=(-2, -1)
        Dimensions along which to apply 2D FFT.

    Returns
    -------
    ComplexTensor
        Complex-valued 2D FFT result.
    """
    return safe_fft2(x, dim=dim, norm=self.norm)
inverse_transform
inverse_transform(x: ComplexTensor, dim: tuple[int, int] = (-2, -1)) -> Tensor

Apply inverse 2D FFT.

Parameters:

Name Type Description Default
x ComplexTensor

Complex-valued FFT coefficients.

required
dim tuple[int, int]

Dimensions along which to apply inverse FFT.

(-2, -1)

Returns:

Type Description
Tensor

Inverse FFT result.

Source code in spectrans/transforms/fourier.py
def inverse_transform(self, x: ComplexTensor, dim: tuple[int, int] = (-2, -1)) -> Tensor:
    """Apply inverse 2D FFT.

    Parameters
    ----------
    x : ComplexTensor
        Complex-valued FFT coefficients.
    dim : tuple[int, int], default=(-2, -1)
        Dimensions along which to apply inverse FFT.

    Returns
    -------
    Tensor
        Inverse FFT result.
    """
    return safe_ifft2(x, dim=dim, norm=self.norm)
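A 2D FFT is separable: it equals a 1D FFT along one dimension followed by a 1D FFT along the other, and with `norm="ortho"` the per-axis factors combine to \(1/\sqrt{HW}\). A minimal NumPy sketch on a batched input, assuming `safe_fft2`/`safe_ifft2` follow the same conventions as `numpy.fft.fft2`/`ifft2`:

```python
import numpy as np

x = np.random.default_rng(0).standard_normal((3, 4, 6))  # (batch, H, W)
X2 = np.fft.fft2(x, axes=(-2, -1), norm="ortho")

# Separable: a 1D FFT along H followed by a 1D FFT along W gives the same result.
X_seq = np.fft.fft(np.fft.fft(x, axis=-2, norm="ortho"), axis=-1, norm="ortho")
assert np.allclose(X2, X_seq)

# With norm="ortho" the round trip is exact and energy is preserved.
assert np.allclose(np.fft.ifft2(X2, axes=(-2, -1), norm="ortho"), x)
assert np.isclose(np.sum(np.abs(X2) ** 2), np.sum(x**2))
```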

RFFT

RFFT(norm: FFTNorm = 'ortho')

Bases: UnitaryTransform

Real Fast Fourier Transform.

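Applies FFT to real-valued inputs, returning only the non-negative frequency components (n // 2 + 1 bins). The negative frequencies are redundant because the spectrum of a real signal is Hermitian-symmetric.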

Parameters:

Name Type Description Default
norm FFTNorm

Normalization mode: "forward", "backward", or "ortho".

"ortho"

Methods:

Name Description
transform

Apply real FFT.

inverse_transform

Apply inverse real FFT.

Source code in spectrans/transforms/fourier.py
def __init__(self, norm: FFTNorm = "ortho"):
    super().__init__()
    self.norm = norm
Functions
transform
transform(x: Tensor, dim: int = -1) -> ComplexTensor

Apply real FFT.

Parameters:

Name Type Description Default
x Tensor

Real-valued input tensor.

required
dim int

Dimension along which to apply RFFT.

-1

Returns:

Type Description
ComplexTensor

Complex-valued RFFT result with only the non-negative frequencies (n // 2 + 1 bins along dim).

Source code in spectrans/transforms/fourier.py
def transform(self, x: Tensor, dim: int = -1) -> ComplexTensor:
    """Apply real FFT.

    Parameters
    ----------
    x : Tensor
        Real-valued input tensor.
    dim : int, default=-1
        Dimension along which to apply RFFT.

    Returns
    -------
    ComplexTensor
        Complex-valued RFFT result (positive frequencies only).
    """
    return safe_rfft(x, dim=dim, norm=self.norm)
inverse_transform
inverse_transform(x: ComplexTensor, dim: int = -1, n: int | None = None) -> Tensor

Apply inverse real FFT.

Parameters:

Name Type Description Default
x ComplexTensor

Complex-valued RFFT coefficients.

required
dim int

Dimension along which to apply inverse RFFT.

-1
n int | None

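Length of the output signal along dim. If None, it is inferred from the input as 2 * (m - 1), where m is the input length along dim (the torch.fft.irfft convention); pass the original length to recover odd-length signals.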

None

Returns:

Type Description
Tensor

Real-valued inverse RFFT result.

Source code in spectrans/transforms/fourier.py
def inverse_transform(self, x: ComplexTensor, dim: int = -1, n: int | None = None) -> Tensor:
    """Apply inverse real FFT.

    Parameters
    ----------
    x : ComplexTensor
        Complex-valued RFFT coefficients.
    dim : int, default=-1
        Dimension along which to apply inverse RFFT.
    n : int | None, default=None
        Length of the output signal. If None, inferred from input.

    Returns
    -------
    Tensor
        Real-valued inverse RFFT result.
    """
    return safe_irfft(x, n=n, dim=dim, norm=self.norm)
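Two details of real FFTs affect round trips: the forward transform keeps n // 2 + 1 bins, and the inverse cannot tell an even length from an odd one unless told. The sketch uses `numpy.fft.rfft`/`irfft`, assuming `safe_rfft`/`safe_irfft` follow the same conventions as `torch.fft.rfft`/`irfft`:

```python
import numpy as np

rng = np.random.default_rng(0)
x_even, x_odd = rng.standard_normal(8), rng.standard_normal(7)

X_even = np.fft.rfft(x_even, norm="ortho")
X_odd = np.fft.rfft(x_odd, norm="ortho")
print(X_even.shape, X_odd.shape)  # (5,) (4,), i.e. n // 2 + 1 bins

# Without n, the inverse assumes the even length 2 * (m - 1).
print(np.fft.irfft(X_odd, norm="ortho").shape)  # (6,), not the original (7,)
assert np.allclose(np.fft.irfft(X_odd, n=7, norm="ortho"), x_odd)
assert np.allclose(np.fft.irfft(X_even, norm="ortho"), x_even)
```

In practice, passing `n=x.shape[dim]` to `inverse_transform` is the safe choice whenever the original length might be odd.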

RFFT2D

RFFT2D(norm: FFTNorm = 'ortho')

Bases: SpectralTransform2D

2D Real Fast Fourier Transform.

Applies a 2D FFT to real-valued inputs, keeping only the non-negative frequencies along the last transformed dimension.

Parameters:

Name Type Description Default
norm FFTNorm

Normalization mode: "forward", "backward", or "ortho".

"ortho"

Methods:

Name Description
transform

Apply 2D real FFT.

inverse_transform

Apply inverse 2D real FFT.

Source code in spectrans/transforms/fourier.py
def __init__(self, norm: FFTNorm = "ortho"):
    self.norm = norm
Functions
transform
transform(x: Tensor, dim: tuple[int, int] = (-2, -1)) -> ComplexTensor

Apply 2D real FFT.

Parameters:

Name Type Description Default
x Tensor

Real-valued input tensor.

required
dim tuple[int, int]

Dimensions along which to apply 2D RFFT.

(-2, -1)

Returns:

Type Description
ComplexTensor

Complex-valued 2D RFFT result.

Source code in spectrans/transforms/fourier.py
def transform(self, x: Tensor, dim: tuple[int, int] = (-2, -1)) -> ComplexTensor:
    """Apply 2D real FFT.

    Parameters
    ----------
    x : Tensor
        Real-valued input tensor.
    dim : tuple[int, int], default=(-2, -1)
        Dimensions along which to apply 2D RFFT.

    Returns
    -------
    ComplexTensor
        Complex-valued 2D RFFT result.
    """
    return safe_rfft2(x, dim=dim, norm=self.norm)
inverse_transform
inverse_transform(x: ComplexTensor, dim: tuple[int, int] = (-2, -1), s: tuple[int, int] | None = None) -> Tensor

Apply inverse 2D real FFT.

Parameters:

Name Type Description Default
x ComplexTensor

Complex-valued RFFT coefficients.

required
dim tuple[int, int]

Dimensions along which to apply inverse RFFT.

(-2, -1)
s tuple[int, int] | None

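Output signal size. If None, the size along dim[1] defaults to 2 * (m - 1), where m is the input size along that dimension, and the size along dim[0] is unchanged; pass s explicitly to recover odd sizes.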

None

Returns:

Type Description
Tensor

Real-valued inverse RFFT result.

Source code in spectrans/transforms/fourier.py
def inverse_transform(
    self, x: ComplexTensor, dim: tuple[int, int] = (-2, -1), s: tuple[int, int] | None = None
) -> Tensor:
    """Apply inverse 2D real FFT.

    Parameters
    ----------
    x : ComplexTensor
        Complex-valued RFFT coefficients.
    dim : tuple[int, int], default=(-2, -1)
        Dimensions along which to apply inverse RFFT.
    s : tuple[int, int] | None, default=None
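        Output signal size. If None, the size along dim[1] defaults to
        2 * (m - 1), where m is the input size along that dimension, and
        the size along dim[0] is unchanged; pass s to recover odd sizes.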

    Returns
    -------
    Tensor
        Real-valued inverse RFFT result.
    """
    return safe_irfft2(x, s=s, dim=dim, norm=self.norm)
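In 2D, only the last transformed dimension is halved, so s plays the role that n plays for RFFT.inverse_transform. Here is a minimal NumPy sketch: numpy.fft.rfft2 and numpy.fft.irfft2 act on the last two axes by default, which matches the (-2, -1) default of dim here. The spectrans wrappers are not used.

```python
import numpy as np

# Batch of 2 real "images" with odd height 5 and width 9.
x = np.random.default_rng(0).standard_normal((2, 5, 9))
X = np.fft.rfft2(x, norm="ortho")
# Height keeps all 5 bins (full complex FFT); width keeps 9 // 2 + 1 = 5 bins.
print(X.shape)  # (2, 5, 5)

# Without s, the width would come back as 2 * (5 - 1) = 8; pass the spatial size.
x_rec = np.fft.irfft2(X, s=x.shape[-2:], norm="ortho")
print(np.allclose(x_rec, x))  # True
```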

SpectralPooling

SpectralPooling(output_size: int | tuple[int, ...], norm: FFTNorm = 'ortho')

Bases: UnitaryTransform

Spectral pooling via frequency domain truncation.

Reduces spatial dimensions by truncating high-frequency components in the Fourier domain.

Parameters:

Name Type Description Default
output_size int | tuple[int, ...]

Target output size after pooling.

required
norm FFTNorm

Normalization mode for FFT operations.

"ortho"

Methods:

Name Description
transform

Apply spectral pooling.

inverse_transform

Inverse is not well-defined for pooling operations.

Source code in spectrans/transforms/fourier.py
def __init__(self, output_size: int | tuple[int, ...], norm: FFTNorm = "ortho"):
    self.output_size = output_size if isinstance(output_size, tuple) else (output_size,)
    self.norm = norm
Functions
transform
transform(x: Tensor, dim: int | tuple[int, ...] = -1) -> Tensor

Apply spectral pooling.

Parameters:

Name Type Description Default
x Tensor

Input tensor to pool.

required
dim int | tuple[int, ...]

Dimensions to pool along.

-1

Returns:

Type Description
Tensor

Spectrally pooled tensor.

Source code in spectrans/transforms/fourier.py
def transform(self, x: Tensor, dim: int | tuple[int, ...] = -1) -> Tensor:
    """Apply spectral pooling.

    Parameters
    ----------
    x : Tensor
        Input tensor to pool.
    dim : int | tuple[int, ...], default=-1
        Dimensions to pool along.

    Returns
    -------
    Tensor
        Spectrally pooled tensor.
    """
    # Convert to frequency domain
    if isinstance(dim, int):
        x_freq = safe_rfft(x, dim=dim, norm=self.norm)
    else:
        x_freq = safe_rfftn(x, dim=dim, norm=self.norm)

    # Truncate frequencies
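    if isinstance(dim, int):
        # Keep the lowest output_size[0] // 2 + 1 bins along `dim` (not always the last axis)
        keep = min(self.output_size[0] // 2 + 1, x_freq.shape[dim])
        truncated = x_freq.narrow(dim, 0, keep)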
    else:
        # Handle multi-dimensional truncation
        slices = [slice(None)] * x_freq.ndim
        for i, d in enumerate(dim):
            size = self.output_size[i] if i < len(self.output_size) else x_freq.shape[d]
            slices[d] = slice(0, size // 2 + 1) if d == dim[-1] else slice(0, size)
        truncated = x_freq[tuple(slices)]

    # Convert back to spatial domain
    if isinstance(dim, int):
        return safe_irfft(truncated, n=self.output_size[0], dim=dim, norm=self.norm)
    else:
        return safe_irfftn(truncated, s=self.output_size, dim=dim, norm=self.norm)
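The following is a minimal NumPy sketch of the single-dimension path of SpectralPooling.transform. The helper name spectral_pool_1d is hypothetical. It keeps the out_len // 2 + 1 lowest RFFT bins and inverts to out_len samples. One consequence shows up in the usage lines: with norm="ortho", the forward transform is normalized by the input length and the inverse by the output length. As a result, a retained sinusoid's amplitude grows by sqrt(n / out_len).

```python
import numpy as np

def spectral_pool_1d(x, out_len, axis=-1, norm="ortho"):
    """Hypothetical helper: downsample `x` along `axis` to `out_len` samples
    by keeping only the lowest-frequency rfft bins."""
    X = np.fft.rfft(x, axis=axis, norm=norm)
    keep = min(out_len // 2 + 1, X.shape[axis])
    X_low = np.take(X, np.arange(keep), axis=axis)  # drop high frequencies
    return np.fft.irfft(X_low, n=out_len, axis=axis, norm=norm)

# A cosine with 3 cycles over 64 samples survives pooling to 16 samples.
n, out_len, k = 64, 16, 3
x = np.cos(2 * np.pi * k * np.arange(n) / n)
y = spectral_pool_1d(x, out_len)

# With norm="ortho", the amplitude is scaled by sqrt(n / out_len) = 2.
expected = np.sqrt(n / out_len) * np.cos(2 * np.pi * k * np.arange(out_len) / out_len)
print(np.allclose(y, expected))  # True
```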
inverse_transform
inverse_transform(x: Tensor, dim: int | tuple[int, ...] = -1) -> Tensor

Inverse is not well-defined for pooling operations.

Source code in spectrans/transforms/fourier.py
def inverse_transform(self, x: Tensor, dim: int | tuple[int, ...] = -1) -> Tensor:
    """Inverse is not well-defined for pooling operations."""
    raise NotImplementedError("Spectral pooling is not invertible due to information loss")

HadamardTransform

HadamardTransform(normalized: bool = True)

Bases: OrthogonalTransform

Fast Walsh-Hadamard Transform.

The Hadamard transform multiplies by a matrix whose entries are all +1 or -1; scaled by 1/sqrt(n), this matrix is orthogonal. The transform size must be a power of 2.

Parameters:

Name Type Description Default
normalized bool

Whether to normalize by 1/sqrt(n) for orthogonality.

True

Methods:

Name Description
transform

Apply Fast Walsh-Hadamard Transform.

inverse_transform

Apply inverse Hadamard transform.

Source code in spectrans/transforms/hadamard.py
def __init__(self, normalized: bool = True):
    super().__init__()
    self.normalized = normalized
Functions
transform
transform(x: Tensor, dim: int = -1) -> Tensor

Apply Fast Walsh-Hadamard Transform.

Parameters:

Name Type Description Default
x Tensor

Input tensor. Size along dim must be a power of 2.

required
dim int

Dimension along which to apply transform.

-1

Returns:

Type Description
Tensor

Hadamard transformed tensor.

Raises:

Type Description
ValueError

If size along dim is not a power of 2.

Source code in spectrans/transforms/hadamard.py
def transform(self, x: Tensor, dim: int = -1) -> Tensor:
    """Apply Fast Walsh-Hadamard Transform.

    Parameters
    ----------
    x : Tensor
        Input tensor. Size along dim must be a power of 2.
    dim : int, default=-1
        Dimension along which to apply transform.

    Returns
    -------
    Tensor
        Hadamard transformed tensor.

    Raises
    ------
    ValueError
        If size along dim is not a power of 2.
    """
    n = x.shape[dim]

    # Check if n is power of 2
    if n & (n - 1) != 0:
        raise ValueError(f"Hadamard transform requires size to be power of 2, got {n}")

    # Move dimension to last for easier processing
    if dim != -1 and dim != x.ndim - 1:
        x = x.transpose(dim, -1)

    # Apply Fast Walsh-Hadamard Transform
    result = self._fwht(x)

    # Normalize if requested
    if self.normalized:
        result = result / math.sqrt(n)

    # Move dimension back
    if dim != -1 and dim != x.ndim - 1:
        result = result.transpose(dim, -1)

    return result
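The _fwht helper called above is not shown on this page. As a reference point, here is a minimal pure-Python sketch of an iterative fast Walsh-Hadamard transform in natural (Hadamard) order, using O(n log n) butterfly updates. The function name fwht is hypothetical, and it processes a single 1D list rather than a batched tensor.

```python
import math

def fwht(values, normalized=True):
    """Hypothetical sketch: iterative fast Walsh-Hadamard transform of a list."""
    a = list(values)
    n = len(a)
    if n == 0 or n & (n - 1):
        raise ValueError(f"length must be a power of 2, got {n}")
    h = 1
    while h < n:
        # Butterflies over blocks of size 2h: (u, v) -> (u + v, u - v)
        for start in range(0, n, 2 * h):
            for j in range(start, start + h):
                u, v = a[j], a[j + h]
                a[j], a[j + h] = u + v, u - v
        h *= 2
    if normalized:
        scale = math.sqrt(n)
        a = [v / scale for v in a]
    return a

print(fwht([1, 2, 3, 4], normalized=False))  # [10, -2, -4, 0]
```

With normalized=True, applying fwht twice returns the input up to round-off, which mirrors the self-inverse property used by inverse_transform.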
inverse_transform
inverse_transform(x: Tensor, dim: int = -1) -> Tensor

Apply inverse Hadamard transform.

The Hadamard transform is self-inverse (up to normalization).

Parameters:

Name Type Description Default
x Tensor

Hadamard coefficients.

required
dim int

Dimension along which to apply inverse transform.

-1

Returns:

Type Description
Tensor

Inverse transformed tensor.

Source code in spectrans/transforms/hadamard.py
def inverse_transform(self, x: Tensor, dim: int = -1) -> Tensor:
    """Apply inverse Hadamard transform.

    The Hadamard transform is self-inverse (up to normalization).

    Parameters
    ----------
    x : Tensor
        Hadamard coefficients.
    dim : int, default=-1
        Dimension along which to apply inverse transform.

    Returns
    -------
    Tensor
        Inverse transformed tensor.
    """
    n = x.shape[dim]

    # For orthogonal Hadamard, inverse is same as forward
    if self.normalized:
        return self.transform(x, dim)
    else:
        # Without normalization, need to scale by 1/n
        result = self.transform(x, dim)
        return result / n
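Assume that _fwht computes the transform in natural (Sylvester) order, as the sequency reordering in SequencyHadamardTransform suggests. The inverse above then follows from the recursive construction of the Hadamard matrix, which is symmetric and squares to a multiple of the identity:

$$
H_1 = \begin{pmatrix} 1 \end{pmatrix}, \qquad
H_{2n} = \begin{pmatrix} H_n & H_n \\ H_n & -H_n \end{pmatrix}, \qquad
H_n H_n = n I_n
$$

Hence H_n / sqrt(n) is its own inverse (normalized=True). The unnormalized inverse is H_n / n, which matches the division by n in inverse_transform.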

HadamardTransform2D

HadamardTransform2D(normalized: bool = True)

Bases: SpectralTransform2D

2D Fast Walsh-Hadamard Transform.

Applies Hadamard transform along two dimensions.

Parameters:

Name Type Description Default
normalized bool

Whether to normalize for orthogonality.

True

Methods:

Name Description
transform

Apply 2D Hadamard transform.

inverse_transform

Apply inverse 2D Hadamard transform.

Source code in spectrans/transforms/hadamard.py
def __init__(self, normalized: bool = True):
    super().__init__()
    self.normalized = normalized
    self.hadamard = HadamardTransform(normalized=False)  # Handle normalization here
Functions
transform
transform(x: Tensor, dim: tuple[int, int] = (-2, -1)) -> Tensor

Apply 2D Hadamard transform.

Parameters:

Name Type Description Default
x Tensor

Input tensor. Sizes along both dims must be powers of 2.

required
dim tuple[int, int]

Dimensions along which to apply transform.

(-2, -1)

Returns:

Type Description
Tensor

2D Hadamard transformed tensor.

Source code in spectrans/transforms/hadamard.py
def transform(self, x: Tensor, dim: tuple[int, int] = (-2, -1)) -> Tensor:
    """Apply 2D Hadamard transform.

    Parameters
    ----------
    x : Tensor
        Input tensor. Sizes along both dims must be powers of 2.
    dim : tuple[int, int], default=(-2, -1)
        Dimensions along which to apply transform.

    Returns
    -------
    Tensor
        2D Hadamard transformed tensor.
    """
    # Apply along first dimension
    result = self.hadamard.transform(x, dim=dim[0])
    # Apply along second dimension
    result = self.hadamard.transform(result, dim=dim[1])

    if self.normalized:
        n1 = x.shape[dim[0]]
        n2 = x.shape[dim[1]]
        result = result / math.sqrt(n1 * n2)

    return result
inverse_transform
inverse_transform(x: Tensor, dim: tuple[int, int] = (-2, -1)) -> Tensor

Apply inverse 2D Hadamard transform.

Parameters:

Name Type Description Default
x Tensor

Hadamard coefficients.

required
dim tuple[int, int]

Dimensions along which to apply inverse transform.

(-2, -1)

Returns:

Type Description
Tensor

Inverse transformed tensor.

Source code in spectrans/transforms/hadamard.py
def inverse_transform(self, x: Tensor, dim: tuple[int, int] = (-2, -1)) -> Tensor:
    """Apply inverse 2D Hadamard transform.

    Parameters
    ----------
    x : Tensor
        Hadamard coefficients.
    dim : tuple[int, int], default=(-2, -1)
        Dimensions along which to apply inverse transform.

    Returns
    -------
    Tensor
        Inverse transformed tensor.
    """
    if self.normalized:
        return self.transform(x, dim)
    else:
        result = self.transform(x, dim)
        n1 = x.shape[dim[0]]
        n2 = x.shape[dim[1]]
        return result / (n1 * n2)
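The 2D transform is separable. Applying the 1D transform along dim[0] and then dim[1] of an n1-by-n2 slice X therefore amounts to multiplying by H_{n1} on the left and H_{n2} on the right; both are symmetric and satisfy H_n H_n = n I_n. Under the same natural-order assumption as the 1D transform, the normalized case reads:

$$
Y = \frac{1}{\sqrt{n_1 n_2}}\, H_{n_1} X H_{n_2}, \qquad
X = \frac{1}{\sqrt{n_1 n_2}}\, H_{n_1} Y H_{n_2}
$$

Without normalization, the forward map is H_{n1} X H_{n2} and the inverse divides by n1 n2, as inverse_transform does.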

SequencyHadamardTransform

SequencyHadamardTransform(normalized: bool = True)

Bases: OrthogonalTransform

Sequency-ordered Hadamard Transform.

Sequency ordering sorts the basis functions by their number of sign changes (zero-crossings), analogous to ordering Fourier basis functions by increasing frequency.

Parameters:

Name Type Description Default
normalized bool

Whether to normalize for orthogonality.

True

Methods:

Name Description
transform

Apply sequency-ordered Hadamard transform.

inverse_transform

Apply inverse sequency-ordered Hadamard transform.

Source code in spectrans/transforms/hadamard.py
def __init__(self, normalized: bool = True):
    super().__init__()
    self.normalized = normalized
    self.hadamard = HadamardTransform(normalized=normalized)
Functions
transform
transform(x: Tensor, dim: int = -1) -> Tensor

Apply sequency-ordered Hadamard transform.

Parameters:

Name Type Description Default
x Tensor

Input tensor. Size along dim must be a power of 2.

required
dim int

Dimension along which to apply transform.

-1

Returns:

Type Description
Tensor

Sequency-ordered Hadamard coefficients.

Source code in spectrans/transforms/hadamard.py
def transform(self, x: Tensor, dim: int = -1) -> Tensor:
    """Apply sequency-ordered Hadamard transform.

    Parameters
    ----------
    x : Tensor
        Input tensor. Size along dim must be a power of 2.
    dim : int, default=-1
        Dimension along which to apply transform.

    Returns
    -------
    Tensor
        Sequency-ordered Hadamard coefficients.
    """
    # Apply standard Hadamard transform
    result = self.hadamard.transform(x, dim)

    # Reorder to sequency ordering
    n = x.shape[dim]
    indices = self._get_sequency_indices(n).to(x.device)

    result = result[..., indices] if dim == -1 else torch.index_select(result, dim, indices)

    return result
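The _get_sequency_indices helper is not shown on this page. One way to build such a permutation is sketched below in pure Python, with hypothetical function names: sort the rows of the natural-order (Sylvester) Hadamard matrix by their number of sign changes. The same permutation also has a closed form, the bit-reversed Gray code of each sequency index, which avoids building the matrix.

```python
def sylvester_hadamard(n):
    """Natural-order Hadamard matrix (n a power of 2) as nested lists of +1/-1."""
    H = [[1]]
    while len(H) < n:
        # H_2n = [[H_n, H_n], [H_n, -H_n]]
        H = [row + row for row in H] + [row + [-v for v in row] for row in H]
    return H

def sign_changes(row):
    return sum(1 for a, b in zip(row, row[1:]) if a != b)

def sequency_indices(n):
    """Natural-order row indices sorted by increasing number of sign changes."""
    H = sylvester_hadamard(n)
    return sorted(range(n), key=lambda i: sign_changes(H[i]))

def sequency_indices_fast(n):
    """Same permutation: bit-reverse the Gray code k ^ (k >> 1) of each index k."""
    bits = n.bit_length() - 1

    def bit_reverse(v):
        return int(format(v, f"0{bits}b")[::-1], 2) if bits else 0

    return [bit_reverse(k ^ (k >> 1)) for k in range(n)]

print(sequency_indices(4))  # [0, 2, 3, 1]
```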
inverse_transform
inverse_transform(x: Tensor, dim: int = -1) -> Tensor

Apply inverse sequency-ordered Hadamard transform.

Parameters:

Name Type Description Default
x Tensor

Sequency-ordered Hadamard coefficients.

required
dim int

Dimension along which to apply inverse transform.

-1

Returns:

Type Description
Tensor

Inverse transformed tensor.

Source code in spectrans/transforms/hadamard.py
def inverse_transform(self, x: Tensor, dim: int = -1) -> Tensor:
    """Apply inverse sequency-ordered Hadamard transform.

    Parameters
    ----------
    x : Tensor
        Sequency-ordered Hadamard coefficients.
    dim : int, default=-1
        Dimension along which to apply inverse transform.

    Returns
    -------
    Tensor
        Inverse transformed tensor.
    """
    n = x.shape[dim]

    # Get inverse permutation
    indices = self._get_sequency_indices(n).to(x.device)
    inverse_indices = torch.zeros_like(indices)
    inverse_indices[indices] = torch.arange(n, device=x.device)

    # Reorder from sequency to natural ordering
    if dim == -1:
        x_reordered = x[..., inverse_indices]
    else:
        x_reordered = torch.index_select(x, dim, inverse_indices)

    # Apply inverse Hadamard
    return self.hadamard.inverse_transform(x_reordered, dim)

SlantTransform

SlantTransform(normalized: bool = True)

Bases: OrthogonalTransform

Slant Transform.

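The Slant transform is similar to the Hadamard transform but adds a "slant" basis vector that decreases linearly in steps. This gives better energy compaction for signals with linear trends, such as image rows with gradual brightness changes.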

Parameters:

Name Type Description Default
normalized bool

Whether to normalize for orthogonality.

True

Methods:

Name Description
transform

Apply Slant transform.

inverse_transform

Apply inverse Slant transform.

Source code in spectrans/transforms/hadamard.py
def __init__(self, normalized: bool = True):
    super().__init__()
    self.normalized = normalized
Functions
transform
transform(x: Tensor, dim: int = -1) -> Tensor

Apply Slant transform.

Parameters:

Name Type Description Default
x Tensor

Input tensor. Size along dim must be a power of 2.

required
dim int

Dimension along which to apply transform.

-1

Returns:

Type Description
Tensor

Slant transformed tensor.

Source code in spectrans/transforms/hadamard.py
def transform(self, x: Tensor, dim: int = -1) -> Tensor:
    """Apply Slant transform.

    Parameters
    ----------
    x : Tensor
        Input tensor. Size along dim must be a power of 2.
    dim : int, default=-1
        Dimension along which to apply transform.

    Returns
    -------
    Tensor
        Slant transformed tensor.
    """
    n = x.shape[dim]

    # Check if n is power of 2
    if n & (n - 1) != 0:
        raise ValueError(f"Slant transform requires size to be a power of 2, got {n}")

    # Create Slant matrix
    slant_matrix = self._create_slant_matrix(n, x.device, x.dtype)

    # Apply transform via matrix multiplication
    if dim == -1 or dim == x.ndim - 1:
        result = torch.matmul(x, slant_matrix.T)
    else:
        x_moved = x.transpose(dim, -1)
        result = torch.matmul(x_moved, slant_matrix.T)
        result = result.transpose(dim, -1)

    return result
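The _create_slant_matrix helper is not shown here, and the general n-point Slant matrix is built recursively. For intuition, the sketch below writes out the standard orthonormal 4-point Slant matrix by hand, in pure Python with hypothetical names. It then compares that matrix with the 4-point Hadamard matrix on a linear ramp: the Slant basis concentrates the ramp into two coefficients, while the Hadamard basis needs three.

```python
import math

R5 = math.sqrt(5)

# Orthonormal 4-point Slant matrix; each row is a basis vector.
SLANT_4 = [
    [1 / 2, 1 / 2, 1 / 2, 1 / 2],                                 # constant
    [3 / (2 * R5), 1 / (2 * R5), -1 / (2 * R5), -3 / (2 * R5)],   # linear "slant"
    [1 / 2, -1 / 2, -1 / 2, 1 / 2],
    [1 / (2 * R5), -3 / (2 * R5), 3 / (2 * R5), -1 / (2 * R5)],
]

# Orthonormal 4-point Hadamard matrix in natural (Sylvester) order.
HADAMARD_4 = [
    [1 / 2, 1 / 2, 1 / 2, 1 / 2],
    [1 / 2, -1 / 2, 1 / 2, -1 / 2],
    [1 / 2, 1 / 2, -1 / 2, -1 / 2],
    [1 / 2, -1 / 2, -1 / 2, 1 / 2],
]

def apply(matrix, x):
    """Matrix-vector product, i.e. the forward transform of x."""
    return [sum(m * v for m, v in zip(row, x)) for row in matrix]

ramp = [0.0, 1.0, 2.0, 3.0]
slant_coeffs = apply(SLANT_4, ramp)        # approximately [3, -sqrt(5), 0, 0]
hadamard_coeffs = apply(HADAMARD_4, ramp)  # [3.0, -1.0, -2.0, 0.0]
```

Because the rows are orthonormal, the inverse is the transpose, which is what inverse_transform relies on.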
inverse_transform
inverse_transform(x: Tensor, dim: int = -1) -> Tensor

Apply inverse Slant transform.

Parameters:

Name Type Description Default
x Tensor

Slant coefficients.

required
dim int

Dimension along which to apply inverse transform.

-1

Returns:

Type Description
Tensor

Inverse transformed tensor.

Source code in spectrans/transforms/hadamard.py
def inverse_transform(self, x: Tensor, dim: int = -1) -> Tensor:
    """Apply inverse Slant transform.

    Parameters
    ----------
    x : Tensor
        Slant coefficients.
    dim : int, default=-1
        Dimension along which to apply inverse transform.

    Returns
    -------
    Tensor
        Inverse transformed tensor.
    """
    n = x.shape[dim]

    # Create Slant matrix (orthogonal, so inverse is transpose)
    slant_matrix = self._create_slant_matrix(n, x.device, x.dtype)

    # Apply inverse transform
    if dim == -1 or dim == x.ndim - 1:
        result = torch.matmul(x, slant_matrix)
    else:
        x_moved = x.transpose(dim, -1)
        result = torch.matmul(x_moved, slant_matrix)
        result = result.transpose(dim, -1)

    return result

DWT1D

DWT1D(wavelet: WaveletType = 'db4', levels: int = 1, mode: str = 'symmetric')

Bases: MultiResolutionTransform

PyWavelets-compatible 1D Discrete Wavelet Transform.

This implementation follows the PyWavelets C implementation so that its results match PyWavelets. It supports multi-level decomposition and reconstructs inputs up to float32 round-off (maximum error below 1e-6) for all supported wavelets.

Parameters:

Name Type Description Default
wavelet WaveletType

Wavelet type (e.g., 'db1', 'db2', 'db4', 'db8', 'sym2', 'coif1').

'db4'
levels int

Number of decomposition levels.

1
mode str

Boundary handling mode (currently only 'symmetric' supported).

'symmetric'

Attributes:

Name Type Description
wavelet str

The wavelet type being used.

levels int

Number of decomposition levels.

mode str

Boundary handling mode.

dec_lo Tensor

Low-pass decomposition filter.

dec_hi Tensor

High-pass decomposition filter.

rec_lo Tensor

Low-pass reconstruction filter.

rec_hi Tensor

High-pass reconstruction filter.

filter_length int

Length of the wavelet filters.

Examples:

>>> dwt = DWT1D(wavelet='db4', levels=3)
>>> x = torch.randn(16, 256)  # batch_size=16, length=256
>>> cA, cD_list = dwt.decompose(x)
>>> print(f"Approximation shape: {cA.shape}")
>>> print(f"Number of detail levels: {len(cD_list)}")
>>> x_rec = dwt.reconstruct((cA, cD_list))
>>> error = torch.max(torch.abs(x - x_rec))
>>> print(f"Reconstruction error: {error:.2e}")

Methods:

Name Description
decompose

Multi-level DWT decomposition.

reconstruct

Multi-level DWT reconstruction.

Source code in spectrans/transforms/wavelet.py
def __init__(self, wavelet: WaveletType = "db4", levels: int = 1, mode: str = "symmetric"):
    super().__init__(levels=levels)
    self.wavelet = wavelet
    self.mode = mode

    if mode != "symmetric":
        msg = f"Mode '{mode}' not yet supported. Only 'symmetric' is implemented."
        raise NotImplementedError(msg)

    # Get filter coefficients from PyWavelets
    dec_lo, dec_hi, rec_lo, rec_hi = get_wavelet_filters(wavelet)

    # Register as buffers (not parameters) since they're fixed
    # Convert to float32 for efficiency in neural networks
    self.register_buffer("dec_lo", dec_lo.float())
    self.register_buffer("dec_hi", dec_hi.float())
    self.register_buffer("rec_lo", rec_lo.float())
    self.register_buffer("rec_hi", rec_hi.float())

    self.filter_length = len(dec_lo)
Functions
decompose
decompose(x: Tensor, levels: int | None = None, dim: int = -1) -> tuple[Tensor, list[Tensor]]

Multi-level DWT decomposition.

Applies a single-level DWT repeatedly, each time to the approximation coefficients from the previous level.

Parameters:

Name Type Description Default
x Tensor

Input signal.

required
levels int | None

Number of levels. If None, uses self.levels.

None
dim int

Dimension to decompose along.

-1

Returns:

Type Description
tuple[Tensor, list[Tensor]]

Tuple of (approximation, [detail_1, ..., detail_N]) where details are ordered from finest to coarsest.

Source code in spectrans/transforms/wavelet.py
def decompose(
    self, x: Tensor, levels: int | None = None, dim: int = -1
) -> tuple[Tensor, list[Tensor]]:
    """Multi-level DWT decomposition.

    Applies a single-level DWT repeatedly, each time to the approximation
    coefficients from the previous level.

    Parameters
    ----------
    x : Tensor
        Input signal.
    levels : int | None
        Number of levels. If None, uses self.levels.
    dim : int
        Dimension to decompose along.

    Returns
    -------
    tuple[Tensor, list[Tensor]]
        Tuple of (approximation, [detail_1, ..., detail_N])
        where details are ordered from finest to coarsest.
    """
    if levels is None:
        levels = self.levels

    current = x
    details = []

    # Apply DWT recursively
    for _ in range(levels):
        cA, cD = self._single_dwt(current, dim=dim)
        details.append(cD)
        current = cA

    return current, details
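Assume DWT1D reproduces PyWavelets' coefficient lengths for mode 'symmetric', as its class docstring states. Then each level of decompose maps a length-N input to floor((N + L - 1) / 2) coefficients for a filter of length L, which is the formula behind pywt.dwt_coeff_len for non-periodic modes. The pure-Python sketch below traces these lengths; the helper name is hypothetical. Note that details come back finest first. This is the reverse of pywt.wavedec, which lists cD_n, ..., cD_1 after the approximation.

```python
def decompose_lengths(n, filter_len, levels):
    """Hypothetical helper: final cA length and [len(cD_1), ..., len(cD_levels)]
    for a length-n signal, using the non-periodic DWT length formula."""
    detail_lens = []
    for _ in range(levels):
        n = (n + filter_len - 1) // 2  # cA and cD share this length at each level
        detail_lens.append(n)
    return n, detail_lens

# db4 has filter length 8; this mirrors the DWT1D example (length 256, 3 levels).
cA_len, cD_lens = decompose_lengths(256, 8, 3)
print(cA_len, cD_lens)  # 38 [131, 69, 38]
```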
reconstruct
reconstruct(coeffs: tuple[Tensor, list[Tensor]], dim: int = -1, output_len: int | None = None) -> Tensor

Multi-level DWT reconstruction.

Parameters:

Name Type Description Default
coeffs tuple[Tensor, list[Tensor]]

Tuple of (approximation, [detail_1, ..., detail_N]).

required
dim int

Dimension to reconstruct along.

-1
output_len int | None

Desired output length. If provided, the reconstructed signal is trimmed to this length; a shorter signal is not padded.

None

Returns:

Type Description
Tensor

Reconstructed signal.

Source code in spectrans/transforms/wavelet.py
def reconstruct(
    self, coeffs: tuple[Tensor, list[Tensor]], dim: int = -1, output_len: int | None = None
) -> Tensor:
    """Multi-level DWT reconstruction.

    Parameters
    ----------
    coeffs : tuple[Tensor, list[Tensor]]
        Tuple of (approximation, [detail_1, ..., detail_N]).
    dim : int
        Dimension to reconstruct along.
    output_len : int | None
        Desired output length. If provided, the reconstructed signal
        is trimmed to this length; a shorter signal is not padded.

    Returns
    -------
    Tensor
        Reconstructed signal.
    """
    cA, details = coeffs
    current = cA

    # Reconstruct from coarsest to finest (reverse order)
    for _i, cD in enumerate(reversed(details)):
        # For multi-level, we need to handle size mismatches
        # The reconstructed signal from a coarser level may be slightly
        # longer than the detail coefficients from the finer level

        if current.shape[dim] > cD.shape[dim]:
            # Trim current to match cD size
            # This happens because IDWT can produce slightly longer output
            target_len = cD.shape[dim]
            if dim == -1 or dim == current.ndim - 1:
                current = current[..., :target_len]
            elif dim == 0:
                current = current[:target_len]
            else:
                # General case
                slices = [slice(None)] * current.ndim
                slices[dim] = slice(0, target_len)
                current = current[tuple(slices)]
        elif current.shape[dim] < cD.shape[dim]:
            # This shouldn't happen with correct decomposition
            raise ValueError(
                f"Reconstructed signal smaller than detail coefficients. "
                f"current shape: {current.shape}, cD shape: {cD.shape}"
            )

        # Now they should have matching sizes
        current = self._single_idwt(current, cD, dim=dim)

    # Trim to desired output length if specified
    if output_len is not None and current.shape[dim] != output_len:
        if dim == -1 or dim == current.ndim - 1:
            current = current[..., :output_len]
        elif dim == 0:
            current = current[:output_len]
        else:
            slices = [slice(None)] * current.ndim
            slices[dim] = slice(0, output_len)
            current = current[tuple(slices)]

    return current
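Under the same PyWavelets-compatibility assumption, a single inverse step turns c coefficients into 2c - L + 2 samples for non-periodic modes. That result can exceed the next finer detail length by one, which is the case reconstruct trims. The sketch below follows that bookkeeping using lengths only, with a hypothetical helper name. It shows why output_len is needed for odd-length inputs.

```python
def reconstruct_lengths(n, filter_len, levels):
    """Hypothetical helper: trace lengths through decompose() then reconstruct()."""
    detail_lens = []
    cur = n
    for _ in range(levels):
        cur = (cur + filter_len - 1) // 2
        detail_lens.append(cur)

    trims = []
    for d in reversed(detail_lens):     # coarsest to finest
        if cur > d:                     # reconstruct() trims current to len(cD)
            trims.append((cur, d))
            cur = d
        cur = 2 * cur - filter_len + 2  # single inverse DWT step
    return cur, trims

print(reconstruct_lengths(256, 8, 3))  # (256, [(70, 69), (132, 131)])
print(reconstruct_lengths(255, 8, 3))  # (256, [(70, 69), (132, 131)])
```

For the odd-length input, passing output_len=255 removes the extra sample.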

DWT2D

DWT2D(wavelet: WaveletType = 'db4', levels: int = 1, mode: str = 'symmetric')

Bases: MultiResolutionTransform2D

PyWavelets-compatible 2D Discrete Wavelet Transform.

Implements 2D DWT using separable 1D transforms, applying a 1D DWT along each dimension in turn. Returns coefficients as (LL, [(HL, LH, HH) per level]), where each detail tuple follows the PyWavelets (cH, cV, cD) order and levels run from finest to coarsest.

Parameters:

Name Type Description Default
wavelet WaveletType

Wavelet type to use.

'db4'
levels int

Number of decomposition levels.

1
mode str

Boundary handling mode.

'symmetric'

Attributes:

Name Type Description
wavelet str

The wavelet type.

levels int

Number of decomposition levels.

mode str

Boundary handling mode.

dwt1d DWT1D

1D DWT instance used for separable transforms.

Examples:

>>> dwt2d = DWT2D(wavelet='db2', levels=2)
>>> image = torch.randn(4, 64, 64)  # batch of 4 images
>>> ll, detail_bands = dwt2d.decompose(image)
>>> print(f"LL shape: {ll.shape}")
>>> for i, (hl, lh, hh) in enumerate(detail_bands):
...     print(f"Level {i+1} - HL: {hl.shape}, LH: {lh.shape}, HH: {hh.shape}")
>>> reconstructed = dwt2d.reconstruct((ll, detail_bands))

Methods:

Name Description
decompose

Multi-level 2D DWT decomposition.

reconstruct

Multi-level 2D DWT reconstruction.

Source code in spectrans/transforms/wavelet.py
def __init__(self, wavelet: WaveletType = "db4", levels: int = 1, mode: str = "symmetric"):
    super().__init__(levels=levels)
    self.wavelet = wavelet
    self.mode = mode

    # Use 1D DWT for separable 2D transform
    self.dwt1d = DWT1D(wavelet=wavelet, levels=1, mode=mode)
Functions
decompose
decompose(x: Tensor, levels: int | None = None, dim: tuple[int, int] = (-2, -1)) -> tuple[Tensor, list[tuple[Tensor, Tensor, Tensor]]]

Multi-level 2D DWT decomposition.

Parameters:

Name Type Description Default
x Tensor

Input tensor; the 2D transform is applied over the dimensions given by dim.

required
levels int | None

Number of levels. If None, uses self.levels.

None
dim tuple[int, int]

Dimensions to decompose along.

(-2, -1)

Returns:

Type Description
tuple[Tensor, list[tuple[Tensor, Tensor, Tensor]]]

Tuple of (LL, [(HL, LH, HH) per level]), with levels ordered from finest to coarsest. Each tuple follows the PyWavelets convention: HL is the horizontal detail, LH the vertical detail, and HH the diagonal detail.

Source code in spectrans/transforms/wavelet.py
def decompose(
    self, x: Tensor, levels: int | None = None, dim: tuple[int, int] = (-2, -1)
) -> tuple[Tensor, list[tuple[Tensor, Tensor, Tensor]]]:
    """Multi-level 2D DWT decomposition.

    Parameters
    ----------
    x : Tensor
        Input tensor; the 2D transform is applied over the dimensions given by dim.
    levels : int | None
        Number of levels. If None, uses self.levels.
    dim : tuple[int, int]
        Dimensions to decompose along.

    Returns
    -------
    tuple[Tensor, list[tuple[Tensor, Tensor, Tensor]]]
        Tuple of (LL, [(HL, LH, HH) per level]), levels ordered finest to
        coarsest. Each tuple follows the PyWavelets convention: HL is the
        horizontal detail, LH the vertical detail, HH the diagonal detail.
    """
    if levels is None:
        levels = self.levels

    current = x
    detail_bands = []

    for _ in range(levels):
        ll, lh, hl, hh = self._single_level_2d(current, dim=dim)
        # PyWavelets returns (cH, cV, cD) = (HL, LH, HH)
        # So we append (HL, LH, HH) to match
        detail_bands.append((hl, lh, hh))
        current = ll

    return current, detail_bands
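decompose lists detail tuples from finest to coarsest, with level 1 first. By contrast, pywt.wavedec2 returns [cAn, (cHn, cVn, cDn), ..., (cH1, cV1, cD1)], coarsest first. The same reversal applies between DWT1D.decompose and pywt.wavedec. The small pure-Python sketch below converts between the two layouts, for example when comparing against PyWavelets output; the helper names are hypothetical.

```python
def to_pywt_order(coeffs):
    """(LL, [level_1, ..., level_n]) -> [LL, level_n, ..., level_1]."""
    ll, detail_bands = coeffs
    return [ll, *reversed(detail_bands)]

def from_pywt_order(coeff_list):
    """[LL, level_n, ..., level_1] -> (LL, [level_1, ..., level_n])."""
    ll, *detail_bands = coeff_list
    return ll, list(reversed(detail_bands))

# Placeholder strings stand in for tensors; each tuple is (HL, LH, HH) = (cH, cV, cD).
coeffs = ("LL2", [("HL1", "LH1", "HH1"), ("HL2", "LH2", "HH2")])
pywt_list = to_pywt_order(coeffs)
print(pywt_list[1])  # ('HL2', 'LH2', 'HH2')
```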
reconstruct
reconstruct(coeffs: tuple[Tensor, list[tuple[Tensor, Tensor, Tensor]]], dim: tuple[int, int] = (-2, -1)) -> Tensor

Multi-level 2D DWT reconstruction.

Parameters:

Name Type Description Default
coeffs tuple[Tensor, list[tuple[Tensor, Tensor, Tensor]]]

Tuple of (LL, [(HL, LH, HH) per level]) following PyWavelets convention.

required
dim tuple[int, int]

Dimensions to reconstruct along.

(-2, -1)

Returns:

Type Description
Tensor

Reconstructed 2D tensor.

Source code in spectrans/transforms/wavelet.py
def reconstruct(
    self,
    coeffs: tuple[Tensor, list[tuple[Tensor, Tensor, Tensor]]],
    dim: tuple[int, int] = (-2, -1),
) -> Tensor:
    """Multi-level 2D DWT reconstruction.

    Parameters
    ----------
    coeffs : tuple[Tensor, list[tuple[Tensor, Tensor, Tensor]]]
        Tuple of (LL, [(HL, LH, HH) per level]) following PyWavelets convention.
    dim : tuple[int, int]
        Dimensions to reconstruct along.

    Returns
    -------
    Tensor
        Reconstructed 2D tensor.
    """
    ll, detail_bands = coeffs
    current = ll

    # Reconstruct from coarsest to finest
    # detail_bands contains (HL, LH, HH) tuples following PyWavelets convention
    for hl, lh, hh in reversed(detail_bands):
        current = self._single_level_2d_reconstruct(current, lh, hl, hh, dim=dim)

    return current