Base Transform Classes
spectrans.transforms.base
Base classes and interfaces for spectral transforms.
This module defines the foundational interfaces for all spectral transform implementations used in spectral transformer neural networks. The class hierarchy accommodates the mathematical diversity of different transform types while maintaining consistent APIs for integration with neural network layers.
The transforms are categorized by their mathematical properties and interface requirements, enabling type-safe composition and clear separation of concerns between different spectral analysis approaches.
Classes:
| Name | Description |
|---|---|
| Transform | Minimal base class for all transforms. |
| SpectralTransform | Interface for simple 1D spectral transforms (FFT, DCT, DST, Hadamard). |
| SpectralTransform2D | Interface for 2D spectral transforms with tuple dimension parameters. |
| MultiResolutionTransform | Interface for transforms returning multi-resolution coefficients (DWT). |
| MultiResolutionTransform2D | Interface for 2D multi-resolution transforms. |
| OrthogonalTransform | Base class for orthogonal transforms preserving inner products. |
| UnitaryTransform | Base class for unitary transforms preserving complex inner products. |
| NeuralSpectralTransform | Base class for learnable transforms with neural network parameters. |
| InvertibleTransform | Base class for transforms guaranteeing exact invertibility. |
| AdaptiveTransform | Base class for transforms with learnable basis functions. |
Examples:
Implementing a custom 1D spectral transform:
>>> from spectrans.transforms.base import SpectralTransform
>>> import torch
>>> class IdentityTransform(SpectralTransform):
...     def transform(self, x, dim=-1):
...         return x
...     def inverse_transform(self, x, dim=-1):
...         return x
Using the invertibility checker:
>>> from spectrans.transforms.base import InvertibleTransform
>>> from spectrans.utils.fft import safe_fft, safe_ifft
>>> class MyTransform(InvertibleTransform):
...     def transform(self, x, dim=-1):
...         return safe_fft(x, dim=dim)
...     def inverse_transform(self, x, dim=-1):
...         return safe_ifft(x, dim=dim)
>>> transform = MyTransform()
>>> test_input = torch.randn(10, 256, dtype=torch.complex64)
>>> is_invertible = transform.check_invertibility(test_input)
Implementing a multi-resolution transform:
>>> from spectrans.transforms.base import MultiResolutionTransform
>>> class SimpleWavelet(MultiResolutionTransform):
...     def decompose(self, x, levels=None, dim=-1):
...         # Implementation returns (approximation, [details])
...         approx = x[..., ::2]  # Simplified downsampling
...         detail = x[..., 1::2]
...         return approx, [detail]
...     def reconstruct(self, coeffs, dim=-1):
...         # Simplified reconstruction: interleave the even and odd samples
...         approx, details = coeffs
...         return torch.stack([approx, details[0]], dim=-1).flatten(-2)
Notes
Design Principles:
- Mathematical Correctness: Each base class enforces the mathematical properties of its transform family (orthogonality, unitarity, etc.).
- Interface Segregation: Different transform types have separate interfaces to avoid forcing incompatible operations into the same signature.
- Composition Support: All transforms can be composed and chained while maintaining proper mathematical properties.
- Gradient Compatibility: All transforms support automatic differentiation for end-to-end neural network training.
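The gradient-compatibility principle can be spot-checked without spectrans itself. The sketch below uses only torch.fft (not the spectrans API) and relies on Parseval's theorem: with norm="ortho", the spectral energy equals sum(x**2), so autograd should return a gradient of exactly 2x.

```python
import torch

x = torch.randn(3, 16, dtype=torch.float64, requires_grad=True)

# Orthonormal FFT: by Parseval's theorem, sum |X|^2 == sum |x|^2.
X = torch.fft.fft(x, dim=-1, norm="ortho")
energy = X.abs().pow(2).sum()
energy.backward()

# The loss equals sum(x**2), so the gradient flowing back through the FFT should be 2x.
print(torch.allclose(x.grad, 2 * x.detach()))
```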
Transform Categories by Mathematical Properties:
Simple Spectral Transforms (SpectralTransform):
- Map Tensor → Tensor along a single dimension
- Examples: 1D FFT, DCT, DST, Hadamard
- Interface: transform(), inverse_transform() with dim parameter
2D Spectral Transforms (SpectralTransform2D):
- Operate on 2D data with tuple dimension parameters
- Examples: 2D FFT for AFNO, 2D DCT
- Interface: transform(), inverse_transform() with dim tuple
Multi-Resolution Transforms (MultiResolutionTransform):
- Decompose into multiple resolution levels
- Examples: Discrete Wavelet Transform (DWT)
- Interface: decompose() returning coefficient tuples, reconstruct() consuming them
Property-Based Classifications:
- Orthogonal: Preserve real inner products (DCT, Hadamard, with orthonormal scaling)
- Unitary: Preserve complex inner products (FFT, with orthonormal scaling)
- Invertible: Guarantee numerical invertibility with tolerance checking
- Adaptive: Learn transform parameters during training
See Also
spectrans.transforms.fourier : FFT implementations
spectrans.transforms.cosine : DCT/DST implementations
spectrans.transforms.hadamard : Hadamard transform implementations
spectrans.transforms.wavelet : Wavelet transform implementations
spectrans.core.types : Type definitions for transform interfaces
Classes
Transform
Bases: Module, ABC
Minimal common base class for all transforms.
This provides the minimal interface that all transforms share, allowing for flexible composition through the registry system without forcing incompatible mathematical operations into the same interface signatures.
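Because every transform exposes the same forward/inverse pair, chaining them is mechanical: apply the forward maps in order and the inverse maps in reverse order. The sketch below is a hypothetical illustration; ComposedTransform, Scale, and OrthoFFT are names invented here, are plain nn.Module subclasses rather than spectrans classes, and say nothing about how the registry system itself composes transforms.

```python
import torch
from torch import nn


class Scale(nn.Module):
    """Hypothetical toy transform: multiply by a nonzero constant."""

    def __init__(self, factor: float):
        super().__init__()
        self.factor = factor

    def transform(self, x, dim=-1):
        return x * self.factor

    def inverse_transform(self, x, dim=-1):
        return x / self.factor


class OrthoFFT(nn.Module):
    """Hypothetical unitary FFT along `dim`."""

    def transform(self, x, dim=-1):
        return torch.fft.fft(x, dim=dim, norm="ortho")

    def inverse_transform(self, x, dim=-1):
        return torch.fft.ifft(x, dim=dim, norm="ortho")


class ComposedTransform(nn.Module):
    """Hypothetical chain: forward maps in order, inverse maps in reverse order."""

    def __init__(self, *transforms: nn.Module):
        super().__init__()
        self.transforms = nn.ModuleList(transforms)

    def transform(self, x, dim=-1):
        for t in self.transforms:
            x = t.transform(x, dim=dim)
        return x

    def inverse_transform(self, x, dim=-1):
        for t in reversed(list(self.transforms)):
            x = t.inverse_transform(x, dim=dim)
        return x


chain = ComposedTransform(Scale(2.0), OrthoFFT())
x = torch.randn(4, 32, dtype=torch.complex128)
y = chain.transform(x)
x_rec = chain.inverse_transform(y)
```

Registering the stages in an nn.ModuleList keeps any learnable stage visible to .parameters() and .to(device) on the composed module.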
SpectralTransform
Bases: Transform
Base class for simple spectral transforms.
For transforms that map Tensor → Tensor along a specified dimension, such as FFT, DCT, DST, and Hadamard transforms. These transforms operate on a single dimension and return tensors of the same shape. Note that a real-input FFT (RFFT) is an exception to shape preservation: it returns only n // 2 + 1 frequency bins along the transformed dimension.
Mathematical operations supported:
- Fourier transforms (FFT, RFFT)
- Discrete Cosine Transform (DCT)
- Discrete Sine Transform (DST)
- Hadamard transform
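Shape preservation is worth checking per transform. With plain torch.fft (used here directly, not through spectrans), the complex FFT keeps the input shape, while the real-input RFFT keeps only the non-redundant half of the spectrum; that is why irfft needs the original length n to invert it exactly.

```python
import torch

x = torch.randn(4, 10)  # real float32 input, transformed along the last dim (n = 10)

full = torch.fft.fft(x, dim=-1)   # complex output, same shape as x
half = torch.fft.rfft(x, dim=-1)  # complex output, only n // 2 + 1 = 6 bins

print(full.shape, full.dtype)  # torch.Size([4, 10]) torch.complex64
print(half.shape)              # torch.Size([4, 6])

# Passing n restores the original length; without it irfft would assume 2 * (6 - 1) = 10,
# which only happens to match here because n is even.
x_back = torch.fft.irfft(half, n=x.shape[-1], dim=-1)
```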
Methods:
| Name | Description |
|---|---|
| transform | Apply forward transform along specified dimension. |
| inverse_transform | Apply inverse transform along specified dimension. |
Attributes:
| Name | Type | Description |
|---|---|---|
| is_orthogonal | bool | Whether the transform is orthogonal. |
| is_unitary | bool | Whether the transform is unitary. |
Attributes
is_orthogonal (property)
Whether the transform is orthogonal.
Returns:
| Type | Description |
|---|---|
| bool | True if the transform preserves inner products. |
is_unitary (property)
Whether the transform is unitary.
Returns:
| Type | Description |
|---|---|
| bool | True if the transform preserves complex inner products. |
Functions
transform (abstractmethod)
Apply forward transform along specified dimension.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Input tensor to transform. | required |
| dim | int | Dimension along which to apply the transform. | -1 |
Returns:
| Type | Description |
|---|---|
| Tensor | Transformed tensor with same shape as input. |
inverse_transform (abstractmethod)
Apply inverse transform along specified dimension.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Transformed tensor to invert. | required |
| dim | int | Dimension along which to apply the inverse transform. | -1 |
Returns:
| Type | Description |
|---|---|
| Tensor | Inverse transformed tensor with same shape as input. |
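As an illustration of the transform/inverse_transform contract, including the dim argument, the sketch below applies an orthonormal DCT-II as an explicit matrix. MatrixDCT and dct_matrix are hypothetical names, not spectrans classes, and a production implementation would more likely use an FFT-based algorithm than this O(n^2) matrix product. Moving the target axis last with torch.movedim lets a single code path handle any dim.

```python
import math
import torch


def dct_matrix(n: int, dtype=torch.float64) -> torch.Tensor:
    """Orthonormal DCT-II matrix C (rows = frequencies), so C @ C.T == I."""
    k = torch.arange(n, dtype=dtype).unsqueeze(1)  # frequency index, shape (n, 1)
    i = torch.arange(n, dtype=dtype).unsqueeze(0)  # sample index, shape (1, n)
    c = math.sqrt(2.0 / n) * torch.cos(math.pi * (2 * i + 1) * k / (2 * n))
    c[0] = c[0] / math.sqrt(2.0)  # DC row gets weight sqrt(1/n) so it has unit norm
    return c


class MatrixDCT:
    """Hypothetical orthonormal DCT-II following transform/inverse_transform(x, dim)."""

    def transform(self, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
        c = dct_matrix(x.shape[dim], dtype=x.dtype).to(x.device)
        # Move the target axis last, compute y = C v for every vector v, move it back.
        y = torch.movedim(x, dim, -1) @ c.T
        return torch.movedim(y, -1, dim)

    def inverse_transform(self, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
        c = dct_matrix(x.shape[dim], dtype=x.dtype).to(x.device)
        # Orthogonality: the inverse of C is C.T, i.e. v = C.T y, which is y @ C in row form.
        v = torch.movedim(x, dim, -1) @ c
        return torch.movedim(v, -1, dim)


x = torch.randn(2, 8, 5, dtype=torch.float64)
dct = MatrixDCT()
y = dct.transform(x, dim=1)  # transform the middle axis (length 8)
x_rec = dct.inverse_transform(y, dim=1)
```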
SpectralTransform2D
Bases: Transform
Base class for 2D spectral transforms.
For transforms that operate on 2D data with tuple dimension parameters, such as 2D FFT, 2D DCT, or other transforms that require operating along multiple dimensions simultaneously.
These transforms are required by AFNO and other architectures that perform 2D Fourier operations.
Methods:
| Name | Description |
|---|---|
| transform | Apply forward 2D transform along specified dimensions. |
| inverse_transform | Apply inverse 2D transform along specified dimensions. |
Functions
transform (abstractmethod)
Apply forward 2D transform along specified dimensions.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Input tensor to transform. | required |
| dim | tuple[int, int] | Dimensions along which to apply the transform. | (-2, -1) |
Returns:
| Type | Description |
|---|---|
| Tensor | Transformed tensor with same shape as input. |
inverse_transform (abstractmethod)
Apply inverse 2D transform along specified dimensions.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Transformed tensor to invert. | required |
| dim | tuple[int, int] | Dimensions along which to apply the inverse transform. | (-2, -1) |
Returns:
| Type | Description |
|---|---|
| Tensor | Inverse transformed tensor with same shape as input. |
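A minimal 2D counterpart, assuming an orthonormal 2D FFT is wanted; OrthoFFT2D is a hypothetical name, not a spectrans class. torch.fft.fft2 accepts the dim tuple directly, and the final comparison confirms that a 2D FFT is separable into two 1D FFTs along the chosen axes.

```python
import torch


class OrthoFFT2D:
    """Hypothetical 2D transform with a tuple `dim`, mirroring the SpectralTransform2D interface."""

    def transform(self, x, dim=(-2, -1)):
        return torch.fft.fft2(x, dim=dim, norm="ortho")

    def inverse_transform(self, x, dim=(-2, -1)):
        return torch.fft.ifft2(x, dim=dim, norm="ortho")


# Channels-last layout (batch, height, width, channels): transform the spatial axes only.
x = torch.randn(2, 8, 8, 3, dtype=torch.complex128)
t = OrthoFFT2D()
y = t.transform(x, dim=(1, 2))
x_rec = t.inverse_transform(y, dim=(1, 2))

# The 2D FFT equals a 1D FFT along height followed by a 1D FFT along width.
separable = torch.fft.fft(torch.fft.fft(x, dim=1, norm="ortho"), dim=2, norm="ortho")
```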
MultiResolutionTransform
Bases: Transform
Base class for multi-resolution transforms.
For transforms that decompose signals into multiple components at different resolution levels, such as Discrete Wavelet Transform (DWT).
These transforms are mathematically different from simple spectral transforms because they return multiple components:
- Approximation coefficients at the coarsest level
- Detail coefficients at each level
This matches the mathematical formulation
$$\mathrm{DWT}(x) = \left\{ c_{A_J},\ \{ c_{D_j} \}_{j=1}^{J} \right\}$$
where J is the number of decomposition levels.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| levels | int | Number of decomposition levels. | 1 |
Methods:
| Name | Description |
|---|---|
| decompose | Decompose signal into multiple resolution levels. |
| reconstruct | Reconstruct signal from multi-resolution coefficients. |
Functions
decompose (abstractmethod)
Decompose signal into multiple resolution levels.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Input tensor to decompose. | required |
| levels | int \| None | Number of levels. If None, use self.levels. | None |
| dim | int | Dimension along which to apply decomposition. | -1 |
|
Returns:
| Type | Description |
|---|---|
| tuple[Tensor, list[Tensor]] | Tuple of (approximation_coefficients, detail_coefficients_list) where detail_coefficients_list contains coefficients from coarsest to finest level. |
reconstruct (abstractmethod)
Reconstruct signal from multi-resolution coefficients.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| coeffs | tuple[Tensor, list[Tensor]] | Tuple of (approximation_coefficients, detail_coefficients_list). | required |
| dim | int | Dimension along which to apply reconstruction. | -1 |
|
Returns:
| Type | Description |
|---|---|
| Tensor | Reconstructed tensor. |
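A minimal multi-resolution sketch using the orthonormal Haar wavelet, assuming the input length along the last dimension is divisible by 2**levels. HaarDWT is a hypothetical class, not the spectrans wavelet implementation, and it omits the dim argument for brevity. It returns details ordered from coarsest to finest, as the decompose contract above specifies, and reconstruct consumes them in that order.

```python
import math
import torch


class HaarDWT:
    """Hypothetical multi-level orthonormal Haar DWT along the last dimension."""

    def __init__(self, levels: int = 1):
        self.levels = levels

    def decompose(self, x, levels=None):
        levels = self.levels if levels is None else levels
        approx, details = x, []
        for _ in range(levels):
            even, odd = approx[..., 0::2], approx[..., 1::2]
            details.append((even - odd) / math.sqrt(2))  # high-pass band, finest level first
            approx = (even + odd) / math.sqrt(2)         # low-pass band, half the length
        return approx, details[::-1]  # reorder details coarsest -> finest

    def reconstruct(self, coeffs):
        approx, details = coeffs
        for detail in details:  # coarsest first; each matches the current approx length
            even = (approx + detail) / math.sqrt(2)
            odd = (approx - detail) / math.sqrt(2)
            # Interleave even/odd samples back into a sequence twice as long.
            approx = torch.stack([even, odd], dim=-1).flatten(-2)
        return approx


x = torch.randn(2, 32, dtype=torch.float64)
dwt = HaarDWT(levels=3)
approx, details = dwt.decompose(x)
x_rec = dwt.reconstruct((approx, details))
```

Because each Haar step is orthonormal, the total energy of the approximation and detail coefficients equals the energy of the input.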
MultiResolutionTransform2D
Bases: Transform
Base class for 2D multi-resolution transforms.
For 2D wavelet transforms and other multi-resolution transforms that operate on 2D data. Returns coefficients in the standard 2D wavelet format: (LL, [(LH, HL, HH) per level]).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| levels | int | Number of decomposition levels. | 1 |
|
Methods:
| Name | Description |
|---|---|
| decompose | Decompose 2D signal into multiple resolution levels. |
| reconstruct | Reconstruct 2D signal from multi-resolution coefficients. |
Functions
decompose (abstractmethod)
decompose(x: Tensor, levels: int | None = None, dim: tuple[int, int] = (-2, -1)) -> tuple[Tensor, list[tuple[Tensor, Tensor, Tensor]]]
Decompose 2D signal into multiple resolution levels.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Input 2D tensor to decompose. | required |
| levels | int \| None | Number of levels. If None, use self.levels. | None |
| dim | tuple[int, int] | Dimensions along which to apply decomposition. | (-2, -1) |
|
Returns:
| Type | Description |
|---|---|
| tuple[Tensor, list[tuple[Tensor, Tensor, Tensor]]] | Tuple of (LL_coefficients, [(LH, HL, HH) per level]) where each tuple contains the three high-frequency subbands. |
reconstruct (abstractmethod)
reconstruct(coeffs: tuple[Tensor, list[tuple[Tensor, Tensor, Tensor]]], dim: tuple[int, int] = (-2, -1)) -> Tensor
Reconstruct 2D signal from multi-resolution coefficients.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| coeffs | tuple[Tensor, list[tuple[Tensor, Tensor, Tensor]]] | Tuple of (LL_coefficients, [(LH, HL, HH) per level]). | required |
| dim | tuple[int, int] | Dimensions along which to apply reconstruction. | (-2, -1) |
|
Returns:
| Type | Description |
|---|---|
| Tensor | Reconstructed 2D tensor. |
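For the 2D case, one orthonormal Haar level splits each 2x2 block into one low-pass and three high-pass subbands. The functions below are a hypothetical sketch of the (LL, [(LH, HL, HH) per level]) format, not the spectrans implementation. They list levels coarsest first by analogy with the 1D decompose contract (the 2D docs do not state an order) and assume even spatial sizes at every level. Which band is called LH versus HL varies between libraries.

```python
import torch


def haar2d_level(x):
    """One orthonormal 2D Haar level over the last two dims (even sizes assumed)."""
    a, b = x[..., 0::2, 0::2], x[..., 0::2, 1::2]  # top-left, top-right of each 2x2 block
    c, d = x[..., 1::2, 0::2], x[..., 1::2, 1::2]  # bottom-left, bottom-right
    ll = (a + b + c + d) / 2
    lh = (a + b - c - d) / 2  # top-row sum minus bottom-row sum
    hl = (a - b + c - d) / 2  # left-column sum minus right-column sum
    hh = (a - b - c + d) / 2  # diagonal difference
    return ll, (lh, hl, hh)


def ihaar2d_level(ll, bands):
    """Invert haar2d_level; its 4x4 mixing matrix is symmetric and orthogonal."""
    lh, hl, hh = bands
    h, w = ll.shape[-2:]
    out = ll.new_empty((*ll.shape[:-2], 2 * h, 2 * w))
    out[..., 0::2, 0::2] = (ll + lh + hl + hh) / 2
    out[..., 0::2, 1::2] = (ll + lh - hl - hh) / 2
    out[..., 1::2, 0::2] = (ll - lh + hl - hh) / 2
    out[..., 1::2, 1::2] = (ll - lh - hl + hh) / 2
    return out


def decompose2d(x, levels=1):
    ll, details = x, []
    for _ in range(levels):
        ll, bands = haar2d_level(ll)
        details.append(bands)
    return ll, details[::-1]  # coarsest level first


def reconstruct2d(coeffs):
    ll, details = coeffs
    for bands in details:  # coarsest level first
        ll = ihaar2d_level(ll, bands)
    return ll


x = torch.randn(2, 16, 16, dtype=torch.float64)
ll, details = decompose2d(x, levels=2)
x_rec = reconstruct2d((ll, details))
```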
OrthogonalTransform
Bases: SpectralTransform
Base class for orthogonal transforms.
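Orthogonal transforms preserve inner products, and their inverse is their transpose. This includes the DCT, DST, and Hadamard transforms when they use orthonormal scaling (the unnormalized Hadamard matrix H satisfies H H^T = nI, so it must be divided by sqrt(n) to be orthogonal).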
Attributes:
| Name | Type | Description |
|---|---|---|
| is_orthogonal | bool | Orthogonal transforms preserve inner products. |
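The orthogonality property, and the role of normalization in it, can be verified numerically. The sketch below builds a Hadamard matrix by Sylvester's construction; the hadamard helper is hypothetical, not the spectrans implementation. The raw ±1 matrix satisfies H H^T = nI, and only after dividing by sqrt(n) does its transpose become its inverse.

```python
import math
import torch


def hadamard(n: int, normalized: bool = True) -> torch.Tensor:
    """Sylvester-construction Hadamard matrix; n must be a power of two."""
    if n < 1 or n & (n - 1):
        raise ValueError("n must be a power of two")
    h = torch.ones(1, 1, dtype=torch.float64)
    while h.shape[0] < n:
        # H_2k = [[H_k, H_k], [H_k, -H_k]]
        h = torch.cat([torch.cat([h, h], dim=1), torch.cat([h, -h], dim=1)], dim=0)
    return h / math.sqrt(n) if normalized else h


n = 8
H_raw = hadamard(n, normalized=False)
H = hadamard(n)
eye = torch.eye(n, dtype=torch.float64)

print(torch.allclose(H_raw @ H_raw.T, n * eye))  # True: unnormalized, H H^T = n I
print(torch.allclose(H @ H.T, eye))              # True: orthonormal, inverse == transpose

x = torch.randn(n, dtype=torch.float64)
y = torch.randn(n, dtype=torch.float64)
print(torch.allclose(torch.dot(H @ x, H @ y), torch.dot(x, y)))  # True: inner product preserved
```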
UnitaryTransform
Bases: SpectralTransform
Base class for unitary transforms.
Unitary transforms preserve complex inner products, and their inverse is their conjugate transpose. This includes the Discrete Fourier Transform (DFT/FFT) with orthonormal scaling, for example norm="ortho" in torch.fft.
Attributes:
| Name | Type | Description |
|---|---|---|
| is_unitary | bool | Unitary transforms preserve complex inner products. |
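Unitarity of the DFT is easy to check numerically with torch.fft alone: applying the FFT to the columns of an identity matrix materializes the DFT matrix F, which exposes the conjugate-transpose relationship directly. The comparison also shows that torch's default norm="backward" scaling gives F^H F = nI rather than the identity, so only the orthonormal variant is unitary.

```python
import torch

n = 16
eye = torch.eye(n, dtype=torch.float64).to(torch.complex128)

# Column j of F is the orthonormal DFT of the j-th standard basis vector.
F = torch.fft.fft(eye, dim=0, norm="ortho")
F_backward = torch.fft.fft(eye, dim=0)  # default norm="backward": unscaled forward DFT

print(torch.allclose(F.conj().T @ F, eye))                        # True: F is unitary
print(torch.allclose(F_backward.conj().T @ F_backward, n * eye))  # True: off by a factor n

x = torch.randn(n, dtype=torch.complex128)
y = torch.randn(n, dtype=torch.complex128)
# The inverse is the conjugate transpose: ifft(x) == F^H x.
print(torch.allclose(torch.fft.ifft(x, norm="ortho"), F.conj().T @ x))  # True
# Complex inner products are preserved (vdot conjugates its first argument).
print(torch.allclose(torch.vdot(F @ x, F @ y), torch.vdot(x, y)))  # True
```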
NeuralSpectralTransform
Bases: SpectralTransform
Base class for learnable spectral transforms.
This class is for transforms that can learn their parameters during training, such as learnable filters in the frequency domain.
Methods:
| Name | Description |
|---|---|
| forward | Forward pass through the neural spectral transform. |
Functions
forward
Forward pass through the neural spectral transform.
By default, applies the transform operation. Subclasses can override this for more complex learned behaviors.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Input tensor. | required |
Returns:
| Type | Description |
|---|---|
| Tensor | Output tensor. |
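As one example of a learned behavior a subclass might place in forward, the sketch below multiplies the real FFT of its input by a learnable complex gain per frequency bin. LearnableFrequencyFilter is a hypothetical plain nn.Module, not a NeuralSpectralTransform subclass; it filters along the last dimension only and is initialized to the identity filter, so an untrained layer passes its input through unchanged.

```python
import torch
from torch import nn


class LearnableFrequencyFilter(nn.Module):
    """Hypothetical learnable spectral filter: rfft -> complex gain per bin -> irfft."""

    def __init__(self, seq_len: int):
        super().__init__()
        n_freq = seq_len // 2 + 1
        # One complex gain per rfft bin, stored as (real, imag) pairs; start at gain 1 + 0j.
        weight = torch.zeros(n_freq, 2)
        weight[:, 0] = 1.0
        self.weight = nn.Parameter(weight)
        self.seq_len = seq_len

    def forward(self, x):
        # x: real tensor of shape (..., seq_len)
        spec = torch.fft.rfft(x, dim=-1, norm="ortho")
        gain = torch.view_as_complex(self.weight)  # shape (n_freq,), broadcast over batch
        return torch.fft.irfft(spec * gain, n=self.seq_len, dim=-1, norm="ortho")


layer = LearnableFrequencyFilter(seq_len=32)
x = torch.randn(4, 32)
out = layer(x)
out.pow(2).sum().backward()  # gradients reach the complex gains through view_as_complex
```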
InvertibleTransform
Bases: SpectralTransform
Base class for transforms that guarantee exact invertibility.
These transforms ensure that inverse_transform(transform(x)) == x within numerical precision, which is important for certain architectures that require perfect reconstruction.
Methods:
| Name | Description |
|---|---|
| check_invertibility | Check if transform is invertible for given input. |
Functions
check_invertibility
Check if transform is invertible for given input.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Input tensor to test. | required |
| dim | int | Dimension along which to test invertibility. | -1 |
| rtol | float | Relative tolerance for comparison. | 1e-5 |
| atol | float | Absolute tolerance for comparison. | 1e-8 |
Returns:
| Type | Description |
|---|---|
| bool | True if transform is invertible within tolerance. |
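One plausible way to implement such a check is a round trip followed by torch.allclose, which tests |x_rec - x| <= atol + rtol * |x| element-wise. The helper below is a hypothetical sketch, not the spectrans source. With the default atol of 1e-8, float32 round-off (machine epsilon is about 1.2e-7) can make near-zero elements fail even for an exactly invertible transform, so float64 inputs or a looser atol may be needed.

```python
import torch


def check_invertibility(transform, inverse, x, dim=-1, rtol=1e-5, atol=1e-8):
    """Hypothetical helper: True if inverse(transform(x)) reproduces x within tolerance."""
    x_rec = inverse(transform(x, dim=dim), dim=dim)
    # A real input sent through a complex transform (e.g. fft/ifft) comes back complex;
    # torch.allclose needs matching dtypes, so promote x and compare both parts.
    if x_rec.is_complex() and not x.is_complex():
        x = x.to(x_rec.dtype)
    return torch.allclose(x_rec, x, rtol=rtol, atol=atol)


x = torch.randn(4, 64, dtype=torch.float64)
ok = check_invertibility(torch.fft.fft, torch.fft.ifft, x)  # exact up to round-off
lossy = check_invertibility(
    lambda t, dim=-1: torch.round(t),  # rounding discards information
    lambda t, dim=-1: t,
    x,
)
```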
AdaptiveTransform
Bases: NeuralSpectralTransform
Base class for adaptive transforms with learnable parameters.
Adaptive transforms can learn their basis functions or transformation parameters from data. This is useful for applications where the optimal spectral representation depends on the specific data distribution.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| input_dim | int | Input dimension size. | required |
| learnable | bool | Whether transform parameters are learnable. | True |
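One way to realize a learnable basis while keeping the transform exactly invertible is to constrain it to be orthogonal, so that the inverse is just the transpose. The sketch below assumes that design choice. OrthogonalAdaptiveTransform is hypothetical and is not how spectrans necessarily implements AdaptiveTransform; it uses torch.nn.utils.parametrizations.orthogonal to keep the weight orthogonal during training, and it falls back to a fixed random orthogonal buffer when learnable=False.

```python
import torch
from torch import nn
from torch.nn.utils.parametrizations import orthogonal


class OrthogonalAdaptiveTransform(nn.Module):
    """Hypothetical adaptive transform: a learned (or frozen) orthogonal basis."""

    def __init__(self, input_dim: int, learnable: bool = True):
        super().__init__()
        self.learnable = learnable
        if learnable:
            # The parametrization recomputes proj.weight as an orthogonal matrix on each access.
            self.proj = orthogonal(nn.Linear(input_dim, input_dim, bias=False))
        else:
            # Fixed random orthogonal basis from a QR decomposition, stored as a buffer.
            q, _ = torch.linalg.qr(torch.randn(input_dim, input_dim))
            self.register_buffer("basis", q)

    def _basis(self) -> torch.Tensor:
        return self.proj.weight if self.learnable else self.basis

    def transform(self, x, dim=-1):
        w = self._basis()
        y = torch.movedim(x, dim, -1) @ w.T  # y = W v for each vector v
        return torch.movedim(y, -1, dim)

    def inverse_transform(self, x, dim=-1):
        w = self._basis()
        v = torch.movedim(x, dim, -1) @ w  # v = W^T y because W is orthogonal
        return torch.movedim(v, -1, dim)

    def forward(self, x):
        return self.transform(x)


adaptive = OrthogonalAdaptiveTransform(16)
frozen = OrthogonalAdaptiveTransform(16, learnable=False)
x = torch.randn(4, 16)
```

The orthogonality constraint trades expressiveness for a free, exact inverse; an unconstrained learnable matrix would need an explicit solve or pseudo-inverse instead.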