Mixing Layers

spectrans.layers.mixing

Spectral mixing layer implementations for token mixing.

Provides spectral mixing layers as alternatives to attention mechanisms. These layers operate in the frequency domain using transforms such as the FFT, which keeps token mixing at linear or log-linear computational complexity.

Mixing layers implement different mathematical approaches including parameter-free Fourier mixing (FNet style), learnable complex filters in frequency domain (GFNet style), and variants with adaptive initialization and multi-dimensional operations.

Modules:

Name Description
afno

Adaptive Fourier Neural Operator mixing implementations.

base

Base classes and interfaces for mixing layers.

fourier

Fourier transform-based mixing layers.

global_filter

Global filter networks with learnable parameters.

wavelet

Wavelet transform-based mixing layers.

Classes:

Name Description
AFNOMixing

Adaptive Fourier Neural Operator with mode truncation.

AdaptiveGlobalFilter

Enhanced global filter with adaptive initialization.

FilterMixingLayer

Base class for learnable frequency domain filters.

FourierMixing

2D FFT mixing for both sequence and feature dimensions.

FourierMixing1D

1D FFT mixing along sequence dimension only.

GlobalFilterMixing

Learnable complex filters in frequency domain.

GlobalFilterMixing2D

2D variant with filtering in both dimensions.

MixingLayer

Base class for spectral mixing operations.

RealFourierMixing

Memory-efficient real FFT variant.

SeparableFourierMixing

Configurable sequence and/or feature mixing.

UnitaryMixingLayer

Base class for energy-preserving mixing transforms.

WaveletMixing

1D wavelet mixing using discrete wavelet transform.

WaveletMixing2D

2D wavelet mixing for spatial data processing.

Examples:

Basic Fourier mixing:

>>> from spectrans.layers.mixing import FourierMixing
>>> mixer = FourierMixing(hidden_dim=768)
>>> output = mixer(input_tensor)

Global filter with learnable parameters:

>>> from spectrans.layers.mixing import GlobalFilterMixing
>>> filter_mixer = GlobalFilterMixing(hidden_dim=768, sequence_length=512)
>>> filtered_output = filter_mixer(input_tensor)

Adaptive filtering:

>>> from spectrans.layers.mixing import AdaptiveGlobalFilter
>>> adaptive_mixer = AdaptiveGlobalFilter(
...     hidden_dim=768, sequence_length=512,
...     adaptive_initialization=True, filter_regularization=0.01
... )
>>> adaptive_output = adaptive_mixer(input_tensor)
Notes

Complexity Comparison:

Self-attention costs \(O(n^2 d)\) for sequence length \(n\) and hidden dimension \(d\). Fourier mixing reduces this to \(O(nd \log n)\). Global filtering has the same \(O(nd \log n)\) cost and adds learnable filter parameters.
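To give a feel for the gap between these bounds, the following sketch compares the leading-order terms only. It is a back-of-the-envelope estimate, not a benchmark: constant factors, memory traffic, and the feature-axis transform are ignored, and the helper names `attention_cost` and `fourier_mixing_cost` are made up for this illustration.

```python
import math


def attention_cost(n: int, d: int) -> int:
    """Leading-order term for self-attention: n^2 * d."""
    return n * n * d


def fourier_mixing_cost(n: int, d: int) -> float:
    """Leading-order term for an FFT along the sequence axis: n * d * log2(n)."""
    return n * d * math.log2(n)


# Typical sizes used elsewhere on this page.
n, d = 512, 768
ratio = attention_cost(n, d) / fourier_mixing_cost(n, d)
print(f"attention / Fourier mixing (leading terms only): {ratio:.1f}x")
```

The ratio simplifies to \(n / \log_2 n\), so the advantage grows with sequence length and does not depend on the hidden dimension.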

All mixing layers support batch processing with consistent behavior, gradient computation for end-to-end training, shape preservation where output shape equals input shape, and mathematical property verification for energy and orthogonality.

References

James Lee-Thorp, Joshua Ainslie, Ilya Eckstein, and Santiago Ontanon. 2022. FNet: Mixing tokens with Fourier transforms. In Proceedings of the 2022 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies (NAACL-HLT), pages 4296-4313, Seattle.

Yongming Rao, Wenliang Zhao, Zheng Zhu, Jiwen Lu, and Jie Zhou. 2021. Global filter networks for image classification. In Advances in Neural Information Processing Systems 34 (NeurIPS 2021), pages 980-993.

John Guibas, Morteza Mardani, Zongyi Li, Andrew Tao, Anima Anandkumar, and Bryan Catanzaro. 2022. Adaptive Fourier neural operators: Efficient token mixers for transformers. In Proceedings of the International Conference on Learning Representations (ICLR).

See Also

spectrans.layers.mixing.base : Base classes and interfaces.

spectrans.transforms : Underlying spectral transform implementations.

spectrans.blocks : Transformer blocks that use these mixing layers.

Classes

AFNOMixing

AFNOMixing(hidden_dim: int, max_sequence_length: int, modes_seq: int | None = None, modes_hidden: int | None = None, mlp_ratio: float = 2.0, activation: ActivationType = 'gelu', dropout: float = 0.0)

Bases: MixingLayer

Adaptive Fourier Neural Operator mixing layer.

This layer mixes tokens by applying learnable transformations to a truncated set of low-frequency Fourier modes. Discarding the remaining modes reduces the cost of the frequency-domain operations compared with working on the full spectrum.

Parameters:

Name Type Description Default
hidden_dim int

Hidden dimension of the input/output tensors.

required
max_sequence_length int

Maximum sequence length the model will process.

required
modes_seq int | None

Number of Fourier modes to keep in sequence dimension. If None, defaults to max_sequence_length // 2.

None
modes_hidden int | None

Number of Fourier modes to keep in hidden dimension. If None, defaults to hidden_dim // 2.

None
mlp_ratio float

Expansion ratio for the MLP in Fourier domain. Default is 2.0.

2.0
activation str

Activation function for MLP. Default is 'gelu'.

'gelu'
dropout float

Dropout probability for MLP. Default is 0.0.

0.0

Attributes:

Name Type Description
hidden_dim int

Hidden dimension size.

max_sequence_length int

Maximum supported sequence length.

modes_seq int

Number of retained Fourier modes in sequence dimension.

modes_hidden int

Number of retained Fourier modes in hidden dimension.

mlp_ratio float

MLP expansion ratio.

fourier_weight Parameter

Complex-valued learnable weights for Fourier modes.

mlp Sequential

MLP applied in Fourier domain.

Examples:

>>> import torch
>>> layer = AFNOMixing(hidden_dim=768, max_sequence_length=512, modes_seq=128)
>>> x = torch.randn(32, 512, 768)
>>> output = layer(x)
>>> print(output.shape)
torch.Size([32, 512, 768])

Methods:

Name Description
forward

Apply AFNO mixing to input tensor.

get_spectral_properties

Get mathematical properties of AFNO operation.

from_config

Create AFNOMixing layer from configuration.

Source code in spectrans/layers/mixing/afno.py
def __init__(
    self,
    hidden_dim: int,
    max_sequence_length: int,
    modes_seq: int | None = None,
    modes_hidden: int | None = None,
    mlp_ratio: float = 2.0,
    activation: ActivationType = "gelu",
    dropout: float = 0.0,
):
    super().__init__(hidden_dim=hidden_dim, dropout=dropout)

    self.max_sequence_length = max_sequence_length

    # Set default mode truncation if not specified
    self.modes_seq = modes_seq if modes_seq is not None else max_sequence_length // 2
    self.modes_hidden = modes_hidden if modes_hidden is not None else hidden_dim // 2

    # Ensure modes don't exceed actual dimensions (for rfft)
    # For rfft, the last dimension has size n//2 + 1
    self.modes_seq = min(self.modes_seq, max_sequence_length)
    self.modes_hidden = min(self.modes_hidden, hidden_dim // 2 + 1)

    self.mlp_ratio = mlp_ratio

    # Complex-valued learnable weights for Fourier modes
    # We use real FFT, so last dimension is reduced
    scale = 1 / (self.modes_seq * self.modes_hidden)
    self.fourier_weight = nn.Parameter(
        torch.randn(self.modes_seq, self.modes_hidden, 2) * scale
    )

    # MLP in Fourier domain
    mlp_hidden_dim = int(hidden_dim * mlp_ratio)

    # Activation function
    activation_fn: nn.Module
    if activation == "gelu":
        activation_fn = nn.GELU()
    elif activation == "relu":
        activation_fn = nn.ReLU()
    elif activation == "silu":
        activation_fn = nn.SiLU()
    elif activation == "tanh":
        activation_fn = nn.Tanh()
    elif activation == "sigmoid":
        activation_fn = nn.Sigmoid()
    elif activation == "identity":
        activation_fn = nn.Identity()
    else:
        raise ValueError(f"Unsupported activation: {activation}")

    # MLP operates on real and imaginary parts concatenated
    self.mlp = nn.Sequential(
        nn.Linear(self.modes_seq * self.modes_hidden * 2, mlp_hidden_dim),
        activation_fn,
        nn.Dropout(dropout),
        nn.Linear(mlp_hidden_dim, self.modes_seq * self.modes_hidden * 2),
        nn.Dropout(dropout),
    )

    # Layer normalization
    self.norm = nn.LayerNorm(hidden_dim)

    self._init_weights()
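
The constructor above sizes the Fourier-domain MLP from the retained mode counts, so the number of modes drives the parameter count directly. The sketch below reproduces only the default and clamping arithmetic from the listing in plain Python. The helper name `afno_sizes` and its dictionary keys are hypothetical, and the count covers the two `nn.Linear` layers only (not `fourier_weight` or the layer norm).

```python
def afno_sizes(hidden_dim, max_sequence_length, modes_seq=None,
               modes_hidden=None, mlp_ratio=2.0):
    """Mirror the mode defaults and clamping from the constructor listing."""
    # Defaults: half of each dimension when not specified.
    if modes_seq is None:
        modes_seq = max_sequence_length // 2
    if modes_hidden is None:
        modes_hidden = hidden_dim // 2

    # Clamp to the size of the rfft2 output along each axis.
    modes_seq = min(modes_seq, max_sequence_length)
    modes_hidden = min(modes_hidden, hidden_dim // 2 + 1)

    # Real and imaginary parts are concatenated before the MLP.
    mlp_in_features = modes_seq * modes_hidden * 2
    mlp_hidden_dim = int(hidden_dim * mlp_ratio)

    # Two linear layers, each with weights and biases.
    mlp_parameters = (mlp_in_features * mlp_hidden_dim + mlp_hidden_dim
                      + mlp_hidden_dim * mlp_in_features + mlp_in_features)
    return {
        "modes_seq": modes_seq,
        "modes_hidden": modes_hidden,
        "mlp_in_features": mlp_in_features,
        "mlp_hidden_dim": mlp_hidden_dim,
        "mlp_parameters": mlp_parameters,
    }


sizes = afno_sizes(hidden_dim=768, max_sequence_length=512, modes_seq=128)
print(sizes)
```

Because the MLP input width is the product `modes_seq * modes_hidden * 2`, lowering either mode count shrinks both linear layers proportionally.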
Functions
forward
forward(x: Tensor) -> Tensor

Apply AFNO mixing to input tensor.

Parameters:

Name Type Description Default
x Tensor

Input tensor of shape (batch_size, sequence_length, hidden_dim).

required

Returns:

Type Description
Tensor

Output tensor of same shape as input.

Source code in spectrans/layers/mixing/afno.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Apply AFNO mixing to input tensor.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor of shape (batch_size, sequence_length, hidden_dim).

    Returns
    -------
    torch.Tensor
        Output tensor of same shape as input.
    """
    batch_size, seq_len, hidden_dim = x.shape
    residual = x
    input_dtype = x.dtype

    # Convert to float32 for processing if needed
    if x.dtype != torch.float32:
        x = x.to(torch.float32)
        residual = residual.to(torch.float32)

    # Apply layer norm
    x = self.norm(x)

    # Pad if necessary to match max_sequence_length
    if seq_len < self.max_sequence_length:
        padding = self.max_sequence_length - seq_len
        x = F.pad(x, (0, 0, 0, padding))

    # Step 1: Transform to Fourier space using 2D FFT
    # Treat (sequence, hidden) as 2D spatial dimensions
    # Use safe wrapper to handle MKL issues
    x_ft = safe_rfft2(x, dim=(1, 2), norm="ortho")

    # Step 2: Mode truncation - keep only low-frequency modes
    x_ft_truncated = x_ft[:, : self.modes_seq, : self.modes_hidden]

    # Step 3: Apply learnable transformation in Fourier domain
    # First apply pointwise multiplication with learnable weights
    weight_complex = torch.view_as_complex(self.fourier_weight)
    x_ft_truncated = x_ft_truncated * weight_complex

    # Flatten for MLP processing
    batch_size_ft = x_ft_truncated.shape[0]
    x_ft_flat = torch.view_as_real(x_ft_truncated).reshape(batch_size_ft, -1)

    # Apply MLP
    x_ft_flat = self.mlp(x_ft_flat)

    # Reshape back to complex truncated form
    x_ft_truncated = x_ft_flat.reshape(batch_size_ft, self.modes_seq, self.modes_hidden, 2)
    x_ft_truncated = torch.view_as_complex(x_ft_truncated)

    # Step 4: Zero-pad back to original size
    x_ft_padded = torch.zeros(
        (batch_size, self.max_sequence_length, hidden_dim // 2 + 1),
        dtype=x_ft.dtype,
        device=x_ft.device,
    )
    x_ft_padded[:, : self.modes_seq, : self.modes_hidden] = x_ft_truncated

    # Step 5: Inverse FFT to get back to spatial domain
    # Use safe wrapper to handle MKL issues
    x_spatial = safe_irfft2(
        x_ft_padded, s=(self.max_sequence_length, hidden_dim), dim=(1, 2), norm="ortho"
    )

    # Remove padding if it was added
    if seq_len < self.max_sequence_length:
        x_spatial = x_spatial[:, :seq_len, :]

    # Step 6: Add residual connection
    output = residual + x_spatial

    # Convert back to original dtype if needed
    if output.dtype != input_dtype:
        output = output.to(input_dtype)

    return output
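
Steps 2, 4, and 5 of the listing (truncate, zero-pad, inverse transform) amount to a low-pass projection. The sketch below shows that idea in one dimension using only the standard library. It is not the library's code path: `rfft_ortho`, `irfft_ortho`, and `truncate_modes` are illustrative helpers written for this example, and the learnable weights and MLP are left out.

```python
import cmath
import math


def rfft_ortho(x):
    """Orthonormal DFT of a real signal, keeping bins 0 .. n//2."""
    n = len(x)
    return [
        sum(x[t] * cmath.exp(-2j * math.pi * k * t / n) for t in range(n))
        / math.sqrt(n)
        for k in range(n // 2 + 1)
    ]


def irfft_ortho(half, n):
    """Inverse of rfft_ortho: rebuild the full spectrum by Hermitian symmetry."""
    full = list(half) + [half[n - k].conjugate() for k in range(n // 2 + 1, n)]
    return [
        (sum(full[k] * cmath.exp(2j * math.pi * k * t / n) for k in range(n))
         / math.sqrt(n)).real
        for t in range(n)
    ]


def truncate_modes(x, modes):
    """Keep the lowest `modes` bins, zero the rest, and transform back."""
    half = rfft_ortho(x)
    kept = half[:modes] + [0j] * (len(half) - modes)
    return irfft_ortho(kept, len(x))


n = 8
signal = [
    1.0 + math.cos(2 * math.pi * t / n) + 0.5 * math.cos(2 * math.pi * 3 * t / n)
    for t in range(n)
]
smooth = truncate_modes(signal, modes=2)  # keeps the mean and the slowest cosine
```

With `modes=2` the high-frequency cosine is removed while the constant offset and the slowest oscillation pass through unchanged, which is why the layer reports `energy_preserving: False`.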
get_spectral_properties
get_spectral_properties() -> dict[str, bool]

Get mathematical properties of AFNO operation.

Returns:

Type Description
dict[str, bool]

Mathematical properties of the transform.

Source code in spectrans/layers/mixing/afno.py
def get_spectral_properties(self) -> dict[str, bool]:
    """Get mathematical properties of AFNO operation.

    Returns
    -------
    dict[str, bool]
        Mathematical properties of the transform.
    """
    return {
        "unitary": False,  # Not unitary due to mode truncation and MLP
        "real_output": True,  # Output is real-valued
        "frequency_domain": True,  # Operations in Fourier domain
        "energy_preserving": False,  # Energy not preserved due to truncation
        "learnable_parameters": True,  # Has learnable weights and MLP
        "translation_equivariant": False,  # Not equivariant due to MLP
        "mode_truncation": True,  # Uses Fourier mode truncation
        "adaptive": True,  # Adaptive filtering based on learned parameters
    }
from_config classmethod
from_config(config: AFNOMixingConfig) -> AFNOMixing

Create AFNOMixing layer from configuration.

Parameters:

Name Type Description Default
config AFNOMixingConfig

Configuration object with layer parameters.

required

Returns:

Type Description
AFNOMixing

Configured AFNO mixing layer.

Source code in spectrans/layers/mixing/afno.py
@classmethod
def from_config(cls, config: "AFNOMixingConfig") -> "AFNOMixing":
    """Create AFNOMixing layer from configuration.

    Parameters
    ----------
    config : AFNOMixingConfig
        Configuration object with layer parameters.

    Returns
    -------
    AFNOMixing
        Configured AFNO mixing layer.
    """
    return cls(
        hidden_dim=config.hidden_dim,
        max_sequence_length=config.max_sequence_length,
        modes_seq=config.modes_seq,
        modes_hidden=config.modes_hidden,
        mlp_ratio=config.mlp_ratio,
        activation=config.activation,
        dropout=config.dropout,
    )

FilterMixingLayer

FilterMixingLayer(hidden_dim: int, sequence_length: int, dropout: float = 0.0, norm_eps: float = 1e-05, learnable_filters: bool = True)

Bases: MixingLayer

Base class for frequency-domain filtering operations.

Filter mixing layers apply learnable filters in the frequency domain, enabling selective emphasis or suppression of frequency components for improved sequence modeling capabilities.

Parameters:

Name Type Description Default
hidden_dim int

Hidden dimension of the model.

required
sequence_length int

Expected sequence length for filter initialization.

required
dropout float

Dropout probability for regularization.

0.0
norm_eps float

Epsilon for numerical stability.

1e-5
learnable_filters bool

Whether filters are learnable parameters.

True

Attributes:

Name Type Description
sequence_length int

Expected sequence length.

learnable_filters bool

Whether filters are learnable.

Methods:

Name Description
get_spectral_properties

Get properties specific to filtering operations.

get_filter_response

Get the frequency response of the current filters.

analyze_frequency_response

Analyze the frequency response characteristics.

Source code in spectrans/layers/mixing/base.py
def __init__(
    self,
    hidden_dim: int,
    sequence_length: int,
    dropout: float = 0.0,
    norm_eps: float = 1e-5,
    learnable_filters: bool = True,
):
    super().__init__(hidden_dim, dropout, norm_eps)
    self.sequence_length = sequence_length
    self.learnable_filters = learnable_filters
Functions
get_spectral_properties
get_spectral_properties() -> dict[str, Any]

Get properties specific to filtering operations.

Returns:

Type Description
dict[str, Any]

Dictionary containing filter-specific properties.

Source code in spectrans/layers/mixing/base.py
def get_spectral_properties(self) -> dict[str, Any]:
    """Get properties specific to filtering operations.

    Returns
    -------
    dict[str, Any]
        Dictionary containing filter-specific properties.
    """
    return {
        "frequency_domain": True,
        "learnable_filters": self.learnable_filters,
        "selective_filtering": True,
        "complex_valued": True,
        "energy_preserving": False,  # Filtering can change energy
    }
get_filter_response abstractmethod
get_filter_response() -> Tensor

Get the frequency response of the current filters.

Returns:

Type Description
Tensor

Complex-valued frequency response of shape matching the filter parameters.

Source code in spectrans/layers/mixing/base.py
@abstractmethod
def get_filter_response(self) -> torch.Tensor:
    """Get the frequency response of the current filters.

    Returns
    -------
    torch.Tensor
        Complex-valued frequency response of shape matching the filter parameters.
    """
    pass
analyze_frequency_response
analyze_frequency_response() -> dict[str, Tensor]

Analyze the frequency response characteristics.

Returns:

Type Description
dict[str, Tensor]

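Dictionary containing:
- 'magnitude': Magnitude response
- 'phase': Phase response
- 'group_delay': Group delay response
- 'total_energy': Total energy summed over the last dimension
- 'low_freq_energy': Fraction of energy in the lowest quarter of frequency bins
- 'high_freq_energy': Fraction of energy in the highest quarter of frequency bins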

Source code in spectrans/layers/mixing/base.py
def analyze_frequency_response(self) -> dict[str, torch.Tensor]:
    """Analyze the frequency response characteristics.

    Returns
    -------
    dict[str, torch.Tensor]
        Dictionary containing:
        - 'magnitude': Magnitude response
        - 'phase': Phase response
        - 'group_delay': Group delay response
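        - 'total_energy': Total energy summed over the last dimension
        - 'low_freq_energy': Fraction of energy in the lowest quarter of frequency bins
        - 'high_freq_energy': Fraction of energy in the highest quarter of frequency bins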
    """
    response = self.get_filter_response()

    magnitude = torch.abs(response)
    phase = torch.angle(response)

    # Compute group delay (negative derivative of phase)
    # For discrete signals, use finite differences
    phase_diff = torch.diff(phase, dim=-1)
    group_delay = -phase_diff

    # Analyze energy in different frequency bands
    total_energy = torch.sum(magnitude**2, dim=-1, keepdim=True)
    low_freq_energy = torch.sum(
        magnitude[..., : magnitude.size(-1) // 4] ** 2, dim=-1, keepdim=True
    )
    high_freq_energy = torch.sum(
        magnitude[..., 3 * magnitude.size(-1) // 4 :] ** 2, dim=-1, keepdim=True
    )

    return {
        "magnitude": magnitude,
        "phase": phase,
        "group_delay": group_delay,
        "total_energy": total_energy,
        "low_freq_energy": low_freq_energy / (total_energy + self.norm_eps),
        "high_freq_energy": high_freq_energy / (total_energy + self.norm_eps),
    }
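
The quantities returned above can be checked by hand on a filter whose response is known in closed form. The sketch below is a plain-Python analogue of the listing for a single 1D response, using a pure one-sample delay as the test filter. The function name `analyze` and the choice of test filter are assumptions made for illustration. Like the listing, it does not unwrap the phase, so the finite difference is only meaningful where the phase stays inside \((-\pi, \pi]\).

```python
import cmath
import math


def analyze(response, eps=1e-5):
    """Magnitude, phase, finite-difference group delay, and band energy fractions."""
    magnitude = [abs(h) for h in response]
    phase = [cmath.phase(h) for h in response]

    # Negative first difference of the phase (one element shorter than the input).
    group_delay = [-(b - a) for a, b in zip(phase, phase[1:])]

    bins = len(magnitude)
    total = sum(m * m for m in magnitude)
    low = sum(m * m for m in magnitude[: bins // 4])
    high = sum(m * m for m in magnitude[3 * bins // 4:])
    return {
        "magnitude": magnitude,
        "phase": phase,
        "group_delay": group_delay,
        "total_energy": total,
        "low_freq_energy": low / (total + eps),
        "high_freq_energy": high / (total + eps),
    }


# One-sample delay on a length-16 grid, first 8 bins: H[k] = exp(-2*pi*i*k/16).
delay_filter = [cmath.exp(-2j * math.pi * k / 16) for k in range(8)]
report = analyze(delay_filter)
```

For this all-pass delay the magnitude is flat, so each quarter band holds about a quarter of the energy, and the group delay is the same constant phase step at every bin.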

MixingLayer

MixingLayer(hidden_dim: int, dropout: float = 0.0, norm_eps: float = 1e-05)

Bases: SpectralComponent

Base class for spectral mixing operations.

Mixing layers perform token mixing operations using various spectral transforms instead of traditional attention mechanisms. This class provides spectral-specific functionality including mathematical property verification and standardized interfaces for spectral transform operations.

Parameters:

Name Type Description Default
hidden_dim int

Hidden dimension of the model.

required
dropout float

Dropout probability for regularization.

0.0
norm_eps float

Epsilon for numerical stability in normalization.

1e-5

Attributes:

Name Type Description
hidden_dim int

Hidden dimension of the model.

dropout Module

Dropout layer for regularization.

norm_eps float

Epsilon for numerical stability.

Methods:

Name Description
get_spectral_properties

Get mathematical properties of the spectral operation.

verify_shape_consistency

Verify that input and output shapes are consistent.

compute_spectral_norm

Compute spectral norm for analysis and regularization.

Source code in spectrans/layers/mixing/base.py
def __init__(
    self,
    hidden_dim: int,
    dropout: float = 0.0,
    norm_eps: float = 1e-5,
):
    super().__init__()
    self.hidden_dim = hidden_dim
    self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
    self.norm_eps = norm_eps
Functions
get_spectral_properties abstractmethod
get_spectral_properties() -> dict[str, Any]

Get mathematical properties of the spectral operation.

Returns:

Type Description
dict[str, Any]

Dictionary containing mathematical properties such as:
- 'unitary': bool, whether the transform is unitary
- 'real_output': bool, whether output is guaranteed real
- 'frequency_domain': bool, whether operation occurs in frequency domain
- 'energy_preserving': bool, whether energy is preserved

Source code in spectrans/layers/mixing/base.py
@abstractmethod
def get_spectral_properties(self) -> dict[str, Any]:
    """Get mathematical properties of the spectral operation.

    Returns
    -------
    dict[str, Any]
        Dictionary containing mathematical properties such as:
        - 'unitary': bool, whether the transform is unitary
        - 'real_output': bool, whether output is guaranteed real
        - 'frequency_domain': bool, whether operation occurs in frequency domain
        - 'energy_preserving': bool, whether energy is preserved
    """
    pass
verify_shape_consistency
verify_shape_consistency(input_tensor: Tensor, output_tensor: Tensor) -> bool

Verify that input and output shapes are consistent.

Parameters:

Name Type Description Default
input_tensor Tensor

Input tensor to the mixing layer.

required
output_tensor Tensor

Output tensor from the mixing layer.

required

Returns:

Type Description
bool

True if shapes are consistent, False otherwise.

Source code in spectrans/layers/mixing/base.py
def verify_shape_consistency(
    self, input_tensor: torch.Tensor, output_tensor: torch.Tensor
) -> bool:
    """Verify that input and output shapes are consistent.

    Parameters
    ----------
    input_tensor : torch.Tensor
        Input tensor to the mixing layer.
    output_tensor : torch.Tensor
        Output tensor from the mixing layer.

    Returns
    -------
    bool
        True if shapes are consistent, False otherwise.
    """
    if input_tensor.shape != output_tensor.shape:
        return False

    # Verify batch dimension consistency
    if input_tensor.size(0) != output_tensor.size(0):
        return False

    # Verify sequence length consistency
    if input_tensor.size(1) != output_tensor.size(1):
        return False

    # Verify hidden dimension consistency
    return input_tensor.size(2) == output_tensor.size(2)
compute_spectral_norm
compute_spectral_norm(tensor: Tensor) -> Tensor

Compute spectral norm for analysis and regularization.

Parameters:

Name Type Description Default
tensor Tensor

Input tensor to compute spectral norm for.

required

Returns:

Type Description
Tensor

Spectral norm of the input tensor.

Source code in spectrans/layers/mixing/base.py
def compute_spectral_norm(self, tensor: torch.Tensor) -> torch.Tensor:
    """Compute spectral norm for analysis and regularization.

    Parameters
    ----------
    tensor : torch.Tensor
        Input tensor to compute spectral norm for.

    Returns
    -------
    torch.Tensor
        Spectral norm of the input tensor.
    """
    # Reshape to matrix for spectral norm computation
    batch_size, seq_len, hidden_dim = tensor.shape
    matrix = tensor.view(batch_size * seq_len, hidden_dim)

    # Compute singular values
    _, s, _ = torch.svd(matrix)

    # Return maximum singular value (spectral norm)
    return torch.max(s, dim=-1)[0].mean()
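
The spectral norm of a matrix is its largest singular value. The listing above obtains it from a full SVD of the tensor reshaped to `(batch_size * seq_len, hidden_dim)`. As a dependency-free illustration of the same quantity, the sketch below uses power iteration on \(A^\top A\) instead. This is a different technique from the one in the listing, and `spectral_norm` is a name chosen for this example.

```python
import math


def spectral_norm(matrix, iterations=200):
    """Estimate the largest singular value of a real matrix by power iteration."""
    rows, cols = len(matrix), len(matrix[0])
    v = [1.0 / math.sqrt(cols)] * cols  # unit-norm starting vector

    def matvec(vec):
        return [sum(matrix[i][j] * vec[j] for j in range(cols)) for i in range(rows)]

    for _ in range(iterations):
        u = matvec(v)  # A v
        w = [sum(matrix[i][j] * u[i] for i in range(rows)) for j in range(cols)]  # A^T A v
        norm = math.sqrt(sum(x * x for x in w))
        if norm == 0.0:
            return 0.0  # v lies in the null space; the zero matrix ends up here
        v = [x / norm for x in w]

    # ||A v|| for the converged unit vector v is the largest singular value.
    return math.sqrt(sum(x * x for x in matvec(v)))


print(spectral_norm([[3.0, 0.0], [0.0, 1.0]]))
```

Power iteration only needs matrix-vector products, which can matter when the reshaped matrix is large, but it converges slowly when the two largest singular values are close.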

UnitaryMixingLayer

UnitaryMixingLayer(hidden_dim: int, dropout: float = 0.0, norm_eps: float = 1e-05, energy_tolerance: float = 0.0001)

Bases: MixingLayer

Base class for unitary mixing operations.

Unitary mixing layers preserve energy and inner products, maintaining mathematical properties essential for stable training and theoretical guarantees in spectral transformers.

Parameters:

Name Type Description Default
hidden_dim int

Hidden dimension of the model.

required
dropout float

Dropout probability for regularization.

0.0
norm_eps float

Epsilon for numerical stability.

1e-5
energy_tolerance float

Tolerance for energy preservation verification.

1e-4

Attributes:

Name Type Description
energy_tolerance float

Tolerance for energy preservation checks.

Methods:

Name Description
get_spectral_properties

Get properties specific to unitary transforms.

verify_energy_preservation

Verify energy preservation (Parseval's theorem).

verify_orthogonality

Verify orthogonality of the transform matrix.

Source code in spectrans/layers/mixing/base.py
def __init__(
    self,
    hidden_dim: int,
    dropout: float = 0.0,
    norm_eps: float = 1e-5,
    energy_tolerance: float = 1e-4,
):
    super().__init__(hidden_dim, dropout, norm_eps)
    self.energy_tolerance = energy_tolerance
Functions
get_spectral_properties
get_spectral_properties() -> dict[str, Any]

Get properties specific to unitary transforms.

Returns:

Type Description
dict[str, Any]

Dictionary containing unitary transform properties.

Source code in spectrans/layers/mixing/base.py
def get_spectral_properties(self) -> dict[str, Any]:
    """Get properties specific to unitary transforms.

    Returns
    -------
    dict[str, Any]
        Dictionary containing unitary transform properties.
    """
    return {
        "unitary": True,
        "energy_preserving": True,
        "invertible": True,
        "orthogonal": True,
        "spectrum_preserving": True,
    }
verify_energy_preservation
verify_energy_preservation(input_tensor: Tensor, output_tensor: Tensor) -> bool

Verify energy preservation (Parseval's theorem).

Checks that \(||\mathbf{output}||^2 \approx ||\mathbf{input}||^2\) within tolerance.

Parameters:

Name Type Description Default
input_tensor Tensor

Input tensor before transformation.

required
output_tensor Tensor

Output tensor after transformation.

required

Returns:

Type Description
bool

True if energy is preserved within tolerance.

Source code in spectrans/layers/mixing/base.py
def verify_energy_preservation(
    self, input_tensor: torch.Tensor, output_tensor: torch.Tensor
) -> bool:
    r"""Verify energy preservation (Parseval's theorem).

    Checks that $||\mathbf{output}||^2 \approx ||\mathbf{input}||^2$ within tolerance.

    Parameters
    ----------
    input_tensor : torch.Tensor
        Input tensor before transformation.
    output_tensor : torch.Tensor
        Output tensor after transformation.

    Returns
    -------
    bool
        True if energy is preserved within tolerance.
    """
    input_energy = torch.norm(input_tensor, p=2, dim=-1) ** 2
    output_energy = torch.norm(output_tensor, p=2, dim=-1) ** 2

    energy_diff = torch.abs(input_energy - output_energy)
    max_energy = torch.max(input_energy, output_energy)

    # Relative error should be within tolerance
    relative_error = energy_diff / (max_energy + self.norm_eps)

    return bool(torch.all(relative_error < self.energy_tolerance))
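
The relative-error test above can be reproduced on a single vector. In the sketch below an orthonormal DFT passes the check, as Parseval's theorem predicts, while a spectrum scaled by a filter gain of 0.5 fails it. The helpers `dft_ortho` and `energy_preserved` are illustrative stand-ins written with the standard library, with defaults copied from the signature of `UnitaryMixingLayer`.

```python
import cmath
import math


def dft_ortho(x):
    """Orthonormal DFT (the 'ortho' normalization)."""
    n = len(x)
    return [
        sum(x[t] * cmath.exp(-2j * math.pi * k * t / n) for t in range(n))
        / math.sqrt(n)
        for k in range(n)
    ]


def energy_preserved(inp, out, tol=1e-4, eps=1e-5):
    """Relative energy difference below `tol`, as in the listing."""
    e_in = sum(abs(v) ** 2 for v in inp)
    e_out = sum(abs(v) ** 2 for v in out)
    relative_error = abs(e_in - e_out) / (max(e_in, e_out) + eps)
    return relative_error < tol


x = [0.5, -1.0, 2.0, 0.25]
spectrum = dft_ortho(x)
filtered = [0.5 * v for v in spectrum]  # a gain of 0.5 removes 75% of the energy

print(energy_preserved(x, spectrum))
print(energy_preserved(x, filtered))
```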
verify_orthogonality
verify_orthogonality(transform_matrix: Tensor) -> bool

Verify orthogonality of the transform matrix.

Checks that \(\mathbf{T} \mathbf{T}^H \approx \mathbf{I}\) (identity matrix).

Parameters:

Name Type Description Default
transform_matrix Tensor

Transform matrix to verify.

required

Returns:

Type Description
bool

True if matrix is orthogonal within tolerance.

Source code in spectrans/layers/mixing/base.py
def verify_orthogonality(self, transform_matrix: torch.Tensor) -> bool:
    r"""Verify orthogonality of the transform matrix.

    Checks that $\mathbf{T} \mathbf{T}^H \approx \mathbf{I}$ (identity matrix).

    Parameters
    ----------
    transform_matrix : torch.Tensor
        Transform matrix to verify.

    Returns
    -------
    bool
        True if matrix is orthogonal within tolerance.
    """
    # Compute T @ T^H
    product = torch.matmul(transform_matrix, transform_matrix.conj().transpose(-2, -1))

    # Expected identity matrix
    identity = torch.eye(
        transform_matrix.size(-1), device=transform_matrix.device, dtype=transform_matrix.dtype
    )

    # Check deviation from identity
    deviation = torch.norm(product - identity, p="fro")

    return bool(deviation < self.energy_tolerance)
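
To see what the Frobenius-norm test accepts and rejects, the sketch below builds a small DFT matrix with and without the \(1/\sqrt{n}\) scaling and applies the same \(\mathbf{T}\mathbf{T}^H \approx \mathbf{I}\) criterion. It uses nested lists rather than tensors, and `dft_matrix` and `is_unitary` are hypothetical helper names for this illustration.

```python
import cmath
import math


def dft_matrix(n, ortho=True):
    """n x n DFT matrix; with ortho=True each entry is scaled by 1/sqrt(n)."""
    scale = 1.0 / math.sqrt(n) if ortho else 1.0
    return [
        [scale * cmath.exp(-2j * math.pi * j * k / n) for k in range(n)]
        for j in range(n)
    ]


def is_unitary(T, tol=1e-4):
    """Frobenius norm of (T @ T^H - I) below `tol`."""
    n = len(T)
    deviation_sq = 0.0
    for i in range(n):
        for j in range(n):
            # Entry (i, j) of T @ T^H.
            entry = sum(T[i][k] * T[j][k].conjugate() for k in range(n))
            target = 1.0 if i == j else 0.0
            deviation_sq += abs(entry - target) ** 2
    return math.sqrt(deviation_sq) < tol


print(is_unitary(dft_matrix(8, ortho=True)))
print(is_unitary(dft_matrix(8, ortho=False)))
```

Only the scaled matrix passes: without the \(1/\sqrt{n}\) factor the product is \(n\mathbf{I}\), which is why the choice of FFT normalization matters for unitarity.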

FourierMixing

FourierMixing(hidden_dim: int, dropout: float = 0.0, norm_eps: float = 1e-05, energy_tolerance: float = 0.0001, fft_norm: FFTNorm = 'ortho', keep_complex: bool = False)

Bases: UnitaryMixingLayer

FNet-style Fourier mixing layer.

Implements the core FNet mixing operation using 2D Fourier transforms along both sequence and feature dimensions, providing an alternative to attention with \(O(n \log n)\) complexity in the sequence length \(n\).

The operation performs:

1. 2D FFT across sequence and feature dimensions.
2. Optional real part extraction for the final output (original FNet behavior), or complex values kept for full information preservation.

Parameters:

Name Type Description Default
hidden_dim int

Hidden dimension of the input tensors.

required
dropout float

Dropout probability applied after the mixing operation.

0.0
norm_eps float

Epsilon for numerical stability.

1e-5
energy_tolerance float

Tolerance for energy preservation verification.

1e-4
fft_norm str

Normalization mode for FFT operations ("forward", "backward", "ortho").

"ortho"
keep_complex bool

If True, keeps complex values from FFT. If False (default), takes only the real part as in original FNet.

False

Attributes:

Name Type Description
fft2d FFT2D

2D Fourier transform module.

keep_complex bool

Whether to keep complex values or extract real part.

Methods:

Name Description
forward

Apply Fourier mixing to input tensor.

get_spectral_properties

Get spectral properties of Fourier mixing.

from_config

Create FourierMixing layer from configuration.

Source code in spectrans/layers/mixing/fourier.py
def __init__(
    self,
    hidden_dim: int,
    dropout: float = 0.0,
    norm_eps: float = 1e-5,
    energy_tolerance: float = 1e-4,
    fft_norm: FFTNorm = "ortho",
    keep_complex: bool = False,
):
    super().__init__(hidden_dim, dropout, norm_eps, energy_tolerance)
    self.keep_complex = keep_complex
    # Store transform as non-module attribute to avoid PyTorch module registration
    self.fft2d: FFT2D  # Type annotation for mypy
    object.__setattr__(self, "fft2d", FFT2D(norm=fft_norm))
Functions
forward
forward(x: Tensor) -> Tensor

Apply Fourier mixing to input tensor.

Parameters:

Name Type Description Default
x Tensor

Input tensor of shape (batch_size, sequence_length, hidden_dim).

required

Returns:

Type Description
Tensor

Mixed tensor of same shape. Complex if keep_complex=True, real values only if keep_complex=False (default).

Source code in spectrans/layers/mixing/fourier.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Apply Fourier mixing to input tensor.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor of shape (batch_size, sequence_length, hidden_dim).

    Returns
    -------
    torch.Tensor
        Mixed tensor of same shape. Complex if keep_complex=True,
        real values only if keep_complex=False (default).
    """
    # Apply 2D FFT along last two dimensions (sequence and feature)
    x_freq = self.fft2d.transform(x, dim=(-2, -1))
    # Keep complex values for information preservation or take real part (default)
    x_mixed = x_freq if self.keep_complex else torch.real(x_freq)

    # Apply dropout
    x_mixed = self.dropout(x_mixed)

    return x_mixed  # type: ignore[no-any-return]
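
With the default `keep_complex=False`, the forward pass keeps only the real part of the 2D spectrum, and that step generally discards energy even when the FFT itself is orthonormal. The sketch below illustrates this on a tiny 4-token, 2-feature input using a hand-written 2D DFT. The functions `dft2_ortho` and `energy` are stand-ins written for this example, not the library's `FFT2D` module.

```python
import cmath
import math


def dft2_ortho(x):
    """Orthonormal 2D DFT over (sequence, feature) for a list-of-lists input."""
    rows, cols = len(x), len(x[0])
    scale = 1.0 / math.sqrt(rows * cols)
    return [
        [
            scale * sum(
                x[t][f] * cmath.exp(-2j * math.pi * (k * t / rows + m * f / cols))
                for t in range(rows)
                for f in range(cols)
            )
            for m in range(cols)
        ]
        for k in range(rows)
    ]


def energy(x):
    """Sum of squared magnitudes over all entries."""
    return sum(abs(v) ** 2 for row in x for v in row)


# A single unit impulse at token 1, feature 0.
tokens = [[0.0, 0.0], [1.0, 0.0], [0.0, 0.0], [0.0, 0.0]]
spectrum = dft2_ortho(tokens)
mixed = [[v.real for v in row] for row in spectrum]  # what keep_complex=False keeps

print(energy(tokens), energy(spectrum), energy(mixed))
```

The full complex spectrum has the same energy as the input, while the real part alone carries only part of it. This is consistent with the `energy_preserving: False` entry reported by `get_spectral_properties`.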
get_spectral_properties
get_spectral_properties() -> dict[str, str | bool]

Get spectral properties of Fourier mixing.

Returns:

Type Description
dict[str, str | bool]

Properties including energy preservation and domain information.

Source code in spectrans/layers/mixing/fourier.py
def get_spectral_properties(self) -> dict[str, str | bool]:
    """Get spectral properties of Fourier mixing.

    Returns
    -------
    dict[str, str | bool]
        Properties including energy preservation and domain information.
    """
    return {
        "unitary": False,
        "real_output": True,
        "frequency_domain": True,
        "energy_preserving": False,
        "translation_equivariant": True,
        "learnable_parameters": False,
    }
from_config classmethod
from_config(config: FourierMixingConfig) -> FourierMixing

Create FourierMixing layer from configuration.

Parameters:

Name Type Description Default
config FourierMixingConfig

Configuration object with layer parameters.

required

Returns:

Type Description
FourierMixing

Configured Fourier mixing layer.

Source code in spectrans/layers/mixing/fourier.py
@classmethod
def from_config(cls, config: "FourierMixingConfig") -> "FourierMixing":
    """Create FourierMixing layer from configuration.

    Parameters
    ----------
    config : FourierMixingConfig
        Configuration object with layer parameters.

    Returns
    -------
    FourierMixing
        Configured Fourier mixing layer.
    """
    return cls(
        hidden_dim=config.hidden_dim,
        dropout=config.dropout,
        norm_eps=config.norm_eps,
        energy_tolerance=config.energy_tolerance,
        fft_norm=config.fft_norm,
        keep_complex=config.keep_complex,
    )

FourierMixing1D

FourierMixing1D(hidden_dim: int, dropout: float = 0.0, norm_eps: float = 1e-05, energy_tolerance: float = 0.0001, fft_norm: FFTNorm = 'ortho', keep_complex: bool = False)

Bases: UnitaryMixingLayer

1D Fourier mixing along sequence dimension only.

Applies Fourier transform only along the sequence dimension, preserving feature dimension locality while mixing tokens.

Parameters:

Name Type Description Default
hidden_dim int

Hidden dimension of the input tensors.

required
dropout float

Dropout probability applied after the mixing operation.

0.0
norm_eps float

Epsilon for numerical stability.

1e-5
energy_tolerance float

Tolerance for energy preservation verification.

1e-4
fft_norm str

Normalization mode for FFT operations.

"ortho"
keep_complex bool

If True, keeps complex values from FFT. If False (default), takes only the real part.

False

Attributes:

Name Type Description
fft1d FFT1D

1D Fourier transform module.

keep_complex bool

Whether to keep complex values or extract real part.

Methods:

Name Description
forward

Apply 1D Fourier mixing to input tensor.

get_spectral_properties

Get spectral properties of 1D Fourier mixing.

Source code in spectrans/layers/mixing/fourier.py
def __init__(
    self,
    hidden_dim: int,
    dropout: float = 0.0,
    norm_eps: float = 1e-5,
    energy_tolerance: float = 1e-4,
    fft_norm: FFTNorm = "ortho",
    keep_complex: bool = False,
):
    super().__init__(hidden_dim, dropout, norm_eps, energy_tolerance)
    self.keep_complex = keep_complex
    # Store transform as non-module attribute to avoid PyTorch module registration
    self.fft1d: FFT1D  # Type annotation for mypy
    object.__setattr__(self, "fft1d", FFT1D(norm=fft_norm))
Functions
forward
forward(x: Tensor) -> Tensor

Apply 1D Fourier mixing to input tensor.

Parameters:

Name Type Description Default
x Tensor

Input tensor of shape (batch_size, sequence_length, hidden_dim).

required

Returns:

Type Description
Tensor

Mixed tensor with Fourier transform applied along sequence dimension. Complex if keep_complex=True, real values only if keep_complex=False.

Source code in spectrans/layers/mixing/fourier.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Apply 1D Fourier mixing to input tensor.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor of shape (batch_size, sequence_length, hidden_dim).

    Returns
    -------
    torch.Tensor
        Mixed tensor with Fourier transform applied along sequence dimension.
        Complex if keep_complex=True, real values only if keep_complex=False.
    """
    # Apply 1D FFT along sequence dimension only
    x_freq = self.fft1d.transform(x, dim=1)  # sequence dimension

    # Keep complex values or take real part (default behavior)
    x_mixed = x_freq if self.keep_complex else torch.real(x_freq)

    # Apply dropout
    x_mixed = self.dropout(x_mixed)

    return x_mixed  # type: ignore[no-any-return]
get_spectral_properties
get_spectral_properties() -> dict[str, str | bool]

Get spectral properties of 1D Fourier mixing.

Returns:

Type Description
dict[str, str | bool]

Properties specific to 1D sequence mixing.

Source code in spectrans/layers/mixing/fourier.py
def get_spectral_properties(self) -> dict[str, str | bool]:
    """Get spectral properties of 1D Fourier mixing.

    Returns
    -------
    dict[str, str | bool]
        Properties specific to 1D sequence mixing.
    """
    return {
        "unitary": False,  # Real part extraction breaks unitarity
        "real_output": not self.keep_complex,  # complex output when keep_complex=True
        "frequency_domain": True,
        "energy_preserving": False,
        "sequence_mixing_only": True,
        "feature_preserving": True,
        "learnable_parameters": False,
    }
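
The sequence_mixing_only and feature_preserving flags can be checked with a small experiment. This is a sketch using torch.fft.fft directly, on the assumption that FFT1D wraps it. The helper mix_1d is a hypothetical stand-in for the layer's default forward pass without dropout.

```python
import torch

torch.manual_seed(0)
x = torch.randn(1, 8, 4)  # (batch, sequence_length, hidden_dim)


def mix_1d(t: torch.Tensor) -> torch.Tensor:
    """Real part of an orthonormal FFT along the sequence dimension."""
    return torch.fft.fft(t, dim=1, norm="ortho").real


baseline = mix_1d(x)

# Perturb the first token in feature channel 0 only.
x_perturbed = x.clone()
x_perturbed[0, 0, 0] += 1.0
changed = (mix_1d(x_perturbed) - baseline).abs() > 1e-6

# Channel 0 changes at every sequence position (global token mixing) ...
print(changed[0, :, 0].all().item())
# ... while the other feature channels are untouched.
print(changed[0, :, 1:].any().item())
```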

RealFourierMixing

RealFourierMixing(hidden_dim: int, use_real_fft: bool = True, dropout: float = 0.0, norm_eps: float = 1e-05, energy_tolerance: float = 0.0001, fft_norm: FFTNorm = 'ortho')

Bases: UnitaryMixingLayer

Memory-efficient real Fourier mixing.

Uses real FFT operations to exploit the Hermitian symmetry of real-input spectra. Only about half of the frequency coefficients are stored, which roughly halves memory use and computation for real inputs.

Parameters:

Name Type Description Default
hidden_dim int

Hidden dimension of the input tensors.

required
use_real_fft bool

Whether to use real FFT for efficiency.

True
dropout float

Dropout probability applied after mixing.

0.0
norm_eps float

Epsilon for numerical stability.

1e-5
energy_tolerance float

Tolerance for energy preservation verification.

1e-4
fft_norm str

Normalization mode for FFT operations.

"ortho"

Attributes:

Name Type Description
use_real_fft bool

Whether real FFT is enabled.

rfft RFFT

Real FFT transform for sequence dimension.

rfft2d RFFT2D

Real 2D FFT transform for both dimensions.

Methods:

Name Description
forward

Apply real Fourier mixing to input tensor.

get_spectral_properties

Get spectral properties of real Fourier mixing.

Source code in spectrans/layers/mixing/fourier.py
def __init__(
    self,
    hidden_dim: int,
    use_real_fft: bool = True,
    dropout: float = 0.0,
    norm_eps: float = 1e-5,
    energy_tolerance: float = 1e-4,
    fft_norm: FFTNorm = "ortho",
):
    super().__init__(hidden_dim, dropout, norm_eps, energy_tolerance)
    self.use_real_fft = use_real_fft

    # Type annotation for mypy
    self.fft2d: FFT2D
    # Complex FFT, used when use_real_fft is False or the input is not
    # real-valued; created unconditionally so forward() can always fall back to it
    object.__setattr__(self, "fft2d", FFT2D(norm=fft_norm))

    if use_real_fft:
        # Type annotations for mypy
        self.rfft: RFFT
        self.rfft2d: RFFT2D
        # Store transforms as non-module attributes
        object.__setattr__(self, "rfft", RFFT(norm=fft_norm))
        object.__setattr__(self, "rfft2d", RFFT2D(norm=fft_norm))
Functions
forward
forward(x: Tensor) -> Tensor

Apply real Fourier mixing to input tensor.

Parameters:

Name Type Description Default
x Tensor

Input tensor of shape (batch_size, sequence_length, hidden_dim). Should be real-valued for optimal efficiency.

required

Returns:

Type Description
Tensor

Mixed tensor, guaranteed to be real-valued.

Source code in spectrans/layers/mixing/fourier.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Apply real Fourier mixing to input tensor.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor of shape (batch_size, sequence_length, hidden_dim).
        Should be real-valued for optimal efficiency.

    Returns
    -------
    torch.Tensor
        Mixed tensor, guaranteed to be real-valued.
    """
    if self.use_real_fft and torch.is_floating_point(x):
        # Use real FFT for efficiency
        x_freq = self.rfft2d.transform(x, dim=(-2, -1))
        # Inverse RFFT automatically returns real values
        x_mixed = self.rfft2d.inverse_transform(x_freq, dim=(-2, -1))
    else:
        # Fallback to complex FFT with real part extraction
        x_freq = self.fft2d.transform(x, dim=(-2, -1))
        x_mixed = torch.real(x_freq)

    # Apply dropout
    x_mixed = self.dropout(x_mixed)

    return x_mixed  # type: ignore[no-any-return]
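
In the real-FFT branch above, the forward transform is followed directly by its inverse with nothing applied in between. The sketch below reproduces that round trip with torch.fft.rfft2 and torch.fft.irfft2. It assumes RFFT2D wraps these functions, since the wrapper's source is not shown here. Under that assumption the branch returns its input up to floating-point error, and the half-sized spectrum shows where the memory saving comes from.

```python
import torch

torch.manual_seed(0)
x = torch.randn(2, 16, 8)  # real-valued (batch, sequence_length, hidden_dim)

# Real 2D FFT keeps only the non-redundant half of the last dimension.
x_freq = torch.fft.rfft2(x, dim=(-2, -1), norm="ortho")
print(x_freq.shape)  # last dimension is 8 // 2 + 1 = 5

# Inverting immediately, with the original sizes passed via `s`, recovers the input.
x_back = torch.fft.irfft2(x_freq, s=x.shape[-2:], dim=(-2, -1), norm="ortho")
print(torch.allclose(x_back, x, atol=1e-5))
```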
get_spectral_properties
get_spectral_properties() -> dict[str, str | bool]

Get spectral properties of real Fourier mixing.

Returns:

Type Description
dict[str, str | bool]

Properties including efficiency characteristics.

Source code in spectrans/layers/mixing/fourier.py
def get_spectral_properties(self) -> dict[str, str | bool]:
    """Get spectral properties of real Fourier mixing.

    Returns
    -------
    dict[str, str | bool]
        Properties including efficiency characteristics.
    """
    return {
        "unitary": self.use_real_fft,  # Real FFT preserves unitarity
        "real_output": True,
        "frequency_domain": True,
        "energy_preserving": self.use_real_fft,
        "memory_efficient": self.use_real_fft,
        "hermitian_symmetry": self.use_real_fft,
        "learnable_parameters": False,
    }

SeparableFourierMixing

SeparableFourierMixing(hidden_dim: int, mix_features: bool = True, mix_sequence: bool = True, dropout: float = 0.0, norm_eps: float = 1e-05, energy_tolerance: float = 0.0001, fft_norm: FFTNorm = 'ortho')

Bases: UnitaryMixingLayer

Separable Fourier mixing with sequence and feature transforms.

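Applies separate 1D Fourier transforms along the sequence and feature dimensions, either of which can be disabled. This can be more efficient than a 2D FFT for certain tensor shapes. Because the real part is taken after each 1D transform, the output generally differs from the real part of a single 2D FFT.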

Parameters:

Name Type Description Default
hidden_dim int

Hidden dimension of the input tensors.

required
mix_features bool

Whether to apply FFT along feature dimension.

True
mix_sequence bool

Whether to apply FFT along sequence dimension.

True
dropout float

Dropout probability.

0.0
norm_eps float

Epsilon for numerical stability.

1e-5
energy_tolerance float

Tolerance for energy preservation verification.

1e-4
fft_norm str

FFT normalization mode.

"ortho"

Attributes:

Name Type Description
mix_features bool

Whether feature mixing is enabled.

mix_sequence bool

Whether sequence mixing is enabled.

fft1d FFT1D

1D FFT transform module.

Methods:

Name Description
forward

Apply separable Fourier mixing.

get_spectral_properties

Get properties of separable mixing.

Source code in spectrans/layers/mixing/fourier.py
def __init__(
    self,
    hidden_dim: int,
    mix_features: bool = True,
    mix_sequence: bool = True,
    dropout: float = 0.0,
    norm_eps: float = 1e-5,
    energy_tolerance: float = 1e-4,
    fft_norm: FFTNorm = "ortho",
):
    super().__init__(hidden_dim, dropout, norm_eps, energy_tolerance)
    self.mix_features = mix_features
    self.mix_sequence = mix_sequence
    # Store transform as non-module attribute
    self.fft1d: FFT1D  # Type annotation for mypy
    object.__setattr__(self, "fft1d", FFT1D(norm=fft_norm))

    if not mix_features and not mix_sequence:
        raise ValueError("At least one of mix_features or mix_sequence must be True")
Functions
forward
forward(x: Tensor) -> Tensor

Apply separable Fourier mixing.

Parameters:

Name Type Description Default
x Tensor

Input tensor of shape (batch_size, sequence_length, hidden_dim).

required

Returns:

Type Description
Tensor

Mixed tensor after applying selected transforms.

Source code in spectrans/layers/mixing/fourier.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Apply separable Fourier mixing.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor of shape (batch_size, sequence_length, hidden_dim).

    Returns
    -------
    torch.Tensor
        Mixed tensor after applying selected transforms.
    """
    # Apply sequence mixing (along dim=1)
    if self.mix_sequence:
        x_freq_seq = self.fft1d.transform(x, dim=1)
        x = torch.real(x_freq_seq)

    # Apply feature mixing (along dim=2)
    if self.mix_features:
        x_freq_feat = self.fft1d.transform(x, dim=2)
        x = torch.real(x_freq_feat)

    # Apply dropout
    x = self.dropout(x)

    return x
get_spectral_properties
get_spectral_properties() -> dict[str, str | bool]

Get properties of separable mixing.

Returns:

Type Description
dict[str, str | bool]

Properties reflecting the separable nature.

Source code in spectrans/layers/mixing/fourier.py
def get_spectral_properties(self) -> dict[str, str | bool]:
    """Get properties of separable mixing.

    Returns
    -------
    dict[str, str | bool]
        Properties reflecting the separable nature.
    """
    return {
        "unitary": False,  # Real part extraction
        "real_output": True,
        "frequency_domain": True,
        "energy_preserving": False,
        "separable": True,
        "sequence_mixing": self.mix_sequence,
        "feature_mixing": self.mix_features,
        "learnable_parameters": False,
    }
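
The forward pass takes the real part after each 1D transform, so the separable result is not the same as the real part of one 2D FFT. The comparison below is a sketch using torch.fft functions directly, on the assumption that FFT1D wraps torch.fft.fft. Sizes are arbitrary.

```python
import torch

torch.manual_seed(0)
x = torch.randn(1, 8, 4)  # (batch, sequence_length, hidden_dim)

# Separable path: the real part is taken after each 1D transform.
seq_mixed = torch.fft.fft(x, dim=1, norm="ortho").real
separable = torch.fft.fft(seq_mixed, dim=2, norm="ortho").real

# Single 2D FFT with one real-part extraction at the end.
joint = torch.fft.fft2(x, dim=(-2, -1), norm="ortho").real

# The intermediate real-part step makes the two results differ.
print(torch.allclose(separable, joint, atol=1e-5))

# Without that step, two 1D FFTs compose to the 2D FFT.
composed = torch.fft.fft(torch.fft.fft(x, dim=1, norm="ortho"), dim=2, norm="ortho")
full_2d = torch.fft.fft2(x, dim=(-2, -1), norm="ortho")
print(torch.allclose(composed, full_2d, atol=1e-5))
```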

AdaptiveGlobalFilter

AdaptiveGlobalFilter(hidden_dim: int, sequence_length: int, activation: ActivationType = 'sigmoid', dropout: float = 0.0, norm_eps: float = 1e-05, learnable_filters: bool = True, fft_norm: FFTNorm = 'ortho', filter_init_std: float = 0.02, filter_regularization: float = 0.0, adaptive_initialization: bool = True, spectral_dropout_p: float = 0.0)

Bases: FilterMixingLayer

Adaptive Global Filter with regularization and smart initialization.

Extends global filtering with frequency-aware initialization, optional L2 regularization of the filter parameters, and optional spectral dropout, all aimed at improving training stability.

Parameters:

Name Type Description Default
hidden_dim int

Hidden dimension of input tensors.

required
sequence_length int

Base sequence length used to size the filter parameters. Filters are interpolated when the input length differs.

required
activation ActivationType

Filter activation function.

"sigmoid"
dropout float

Dropout probability.

0.0
norm_eps float

Numerical stability epsilon.

1e-5
learnable_filters bool

Whether filters are learnable.

True
fft_norm str

FFT normalization.

"ortho"
filter_init_std float

Filter initialization standard deviation.

0.02
filter_regularization float

L2 regularization strength for filter parameters.

0.0
adaptive_initialization bool

Whether to use frequency-aware initialization.

True
spectral_dropout_p float

Spectral dropout probability in frequency domain.

0.0

Attributes:

Name Type Description
filter_regularization float

Regularization strength.

adaptive_initialization bool

Whether adaptive initialization is used.

spectral_dropout_p float

Spectral dropout probability.

spectral_dropout Module

Spectral dropout layer.

Methods:

Name Description
forward

Apply adaptive global filtering.

get_filter_response

Get adaptive frequency response.

get_regularization_loss

Compute L2 regularization loss for filter parameters.

get_spectral_properties

Get adaptive filter properties.

Source code in spectrans/layers/mixing/global_filter.py
def __init__(
    self,
    hidden_dim: int,
    sequence_length: int,
    activation: ActivationType = "sigmoid",
    dropout: float = 0.0,
    norm_eps: float = 1e-5,
    learnable_filters: bool = True,
    fft_norm: FFTNorm = "ortho",
    filter_init_std: float = 0.02,
    filter_regularization: float = 0.0,
    adaptive_initialization: bool = True,
    spectral_dropout_p: float = 0.0,
):
    super().__init__(hidden_dim, sequence_length, dropout, norm_eps, learnable_filters)
    self.activation = activation
    self.filter_regularization = filter_regularization
    self.adaptive_initialization = adaptive_initialization
    self.spectral_dropout_p = spectral_dropout_p

    # Initialize filter parameters
    if adaptive_initialization:
        # Frequency-aware initialization: smaller initial values at high frequencies
        frequencies = torch.fft.fftfreq(sequence_length)
        # Weight by inverse frequency magnitude; the 0.1 offset avoids division by zero at DC
        freq_weights = 1.0 / (torch.abs(frequencies) + 0.1)
        freq_weights = freq_weights / freq_weights.mean()

        self.filter_real = nn.Parameter(
            torch.randn(sequence_length, hidden_dim)
            * filter_init_std
            * freq_weights.unsqueeze(-1)
        )
        self.filter_imag = nn.Parameter(
            torch.randn(sequence_length, hidden_dim)
            * filter_init_std
            * freq_weights.unsqueeze(-1)
        )
    else:
        # Standard initialization
        self.filter_real = nn.Parameter(
            torch.randn(sequence_length, hidden_dim) * filter_init_std
        )
        self.filter_imag = nn.Parameter(
            torch.randn(sequence_length, hidden_dim) * filter_init_std
        )

    # Store FFT transform as non-module attribute
    self.fft1d: FFT1D  # Type annotation for mypy
    object.__setattr__(self, "fft1d", FFT1D(norm=fft_norm))

    self.activation_fn: Callable[[Tensor], Tensor]
    if activation == "sigmoid":
        self.activation_fn = nn.Sigmoid()
    elif activation == "tanh":
        self.activation_fn = nn.Tanh()
    elif activation == "identity":
        self.activation_fn = nn.Identity()
    else:
        raise ValueError(f"Unknown activation: {activation}")

    # Spectral dropout for regularization
    if spectral_dropout_p > 0:
        self.spectral_dropout: nn.Module = nn.Dropout2d(spectral_dropout_p)
    else:
        self.spectral_dropout = nn.Identity()
Functions
forward
forward(x: Tensor) -> Tensor

Apply adaptive global filtering.

Parameters:

Name Type Description Default
x Tensor

Input tensor of shape (batch_size, sequence_length, hidden_dim).

required

Returns:

Type Description
Tensor

Adaptively filtered tensor.

Source code in spectrans/layers/mixing/global_filter.py
def forward(self, x: Tensor) -> Tensor:
    """Apply adaptive global filtering.

    Parameters
    ----------
    x : Tensor
        Input tensor of shape (batch_size, sequence_length, hidden_dim).

    Returns
    -------
    Tensor
        Adaptively filtered tensor.
    """
    # Transform to frequency domain
    x_freq = self.fft1d.transform(x, dim=1)

    # Get actual sequence length
    seq_len = x_freq.shape[1]

    # Adapt filter to actual sequence length using interpolation
    if seq_len != self.sequence_length:
        # Use interpolation to adapt filters to the actual sequence length
        filter_real = (
            nn.functional.interpolate(
                self.filter_real.T.unsqueeze(0),  # (1, hidden_dim, sequence_length)
                size=seq_len,
                mode="linear",
                align_corners=False,
            )
            .squeeze(0)
            .T
        )  # (seq_len, hidden_dim)

        filter_imag = (
            nn.functional.interpolate(
                self.filter_imag.T.unsqueeze(0),  # (1, hidden_dim, sequence_length)
                size=seq_len,
                mode="linear",
                align_corners=False,
            )
            .squeeze(0)
            .T
        )  # (seq_len, hidden_dim)
    else:
        filter_real = self.filter_real
        filter_imag = self.filter_imag

    # Create complex filter with activation
    filter_complex = make_complex(
        self.activation_fn(filter_real), self.activation_fn(filter_imag)
    )

    # Apply spectral dropout to filter (not input)
    if self.training and self.spectral_dropout_p > 0:
        filter_complex = self.spectral_dropout(filter_complex)

    # Apply filtering in frequency domain
    filtered_freq = complex_multiply(x_freq, filter_complex)

    # Transform back to time domain
    filtered_time = self.fft1d.inverse_transform(filtered_freq, dim=1)

    # Take real part
    output = torch.real(filtered_time)

    # Apply standard dropout
    output = self.dropout(output)

    return output  # type: ignore[no-any-return]
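
The filter resizing step can be tried on its own. The sketch below applies the same torch.nn.functional.interpolate call to a stand-in parameter tensor. The sizes and the 0.02 scale are illustrative values, not taken from a trained layer.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
base_len, hidden_dim, seq_len = 16, 4, 24  # arbitrary sizes

# Stand-in for a stored filter of shape (sequence_length, hidden_dim).
filter_real = torch.randn(base_len, hidden_dim) * 0.02

# F.interpolate expects (batch, channels, length), so hidden_dim moves to the channel axis.
resized = (
    F.interpolate(
        filter_real.T.unsqueeze(0),  # (1, hidden_dim, base_len)
        size=seq_len,
        mode="linear",
        align_corners=False,
    )
    .squeeze(0)
    .T
)  # (seq_len, hidden_dim)

print(resized.shape)  # torch.Size([24, 4])
```

Linear interpolation only forms weighted averages of neighboring entries, so the resized filter stays within the value range of the stored one.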
get_filter_response
get_filter_response() -> Tensor

Get adaptive frequency response.

Returns:

Type Description
Tensor

Complex frequency response with current parameters.

Source code in spectrans/layers/mixing/global_filter.py
def get_filter_response(self) -> Tensor:
    """Get adaptive frequency response.

    Returns
    -------
    Tensor
        Complex frequency response with current parameters.
    """
    return make_complex(
        self.activation_fn(self.filter_real), self.activation_fn(self.filter_imag)
    )
get_regularization_loss
get_regularization_loss() -> Tensor

Compute L2 regularization loss for filter parameters.

Returns:

Type Description
Tensor

Scalar regularization loss.

Source code in spectrans/layers/mixing/global_filter.py
def get_regularization_loss(self) -> Tensor:
    """Compute L2 regularization loss for filter parameters.

    Returns
    -------
    Tensor
        Scalar regularization loss.
    """
    if self.filter_regularization <= 0:
        return torch.tensor(0.0, device=self.filter_real.device)

    real_loss = torch.norm(self.filter_real, p=2) ** 2
    imag_loss = torch.norm(self.filter_imag, p=2) ** 2

    return self.filter_regularization * (real_loss + imag_loss)  # type: ignore[no-any-return]
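
The forward pass shown for this class does not add the penalty itself, so a training loop would presumably need to add it to the task loss. The sketch below reproduces the computation on stand-in tensors. The value of filter_regularization is a hypothetical choice.

```python
import torch

torch.manual_seed(0)
filter_real = torch.randn(16, 4) * 0.02
filter_imag = torch.randn(16, 4) * 0.02
filter_regularization = 1e-3  # hypothetical strength

# Squared L2 norm of each flattened parameter tensor, as in the method above.
penalty = torch.norm(filter_real, p=2) ** 2 + torch.norm(filter_imag, p=2) ** 2
reg_loss = filter_regularization * penalty

# Equivalent formulation: the sum of squared entries.
same = filter_regularization * (filter_real.pow(2).sum() + filter_imag.pow(2).sum())
print(torch.allclose(reg_loss, same))

# In a training step the term would be added to the task loss, for example:
# loss = task_loss + layer.get_regularization_loss()
```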
get_spectral_properties
get_spectral_properties() -> dict[str, str | bool | int]

Get adaptive filter properties.

Returns:

Type Description
dict[str, str | bool | int]

Comprehensive properties including adaptive features.

Source code in spectrans/layers/mixing/global_filter.py
def get_spectral_properties(self) -> dict[str, str | bool | int]:
    """Get adaptive filter properties.

    Returns
    -------
    dict[str, str | bool | int]
        Comprehensive properties including adaptive features.
    """
    return {
        "frequency_domain": True,
        "learnable_filters": True,
        "complex_valued": True,
        "selective_filtering": True,
        "energy_preserving": False,
        "adaptive_initialization": self.adaptive_initialization,
        "regularization": self.filter_regularization > 0,
        "spectral_dropout": self.spectral_dropout_p > 0,
        "activation": self.activation,
        "parameter_count": 2 * self.sequence_length * self.hidden_dim,
    }

GlobalFilterMixing

GlobalFilterMixing(hidden_dim: int, sequence_length: int, activation: ActivationType = 'sigmoid', dropout: float = 0.0, norm_eps: float = 1e-05, learnable_filters: bool = True, fft_norm: FFTNorm = 'ortho', filter_init_std: float = 0.02)

Bases: FilterMixingLayer

Global Filter Network mixing layer.

Implements the core GFNet mixing operation with learnable complex filters applied in the frequency domain along the sequence dimension.

When the input sequence length differs from sequence_length, the layer linearly interpolates its filters to the actual length. This lets one set of learned filters process variable-length inputs, unlike fixed-size filtering, which is tied to a single sequence length.

Parameters:

Name Type Description Default
hidden_dim int

Hidden dimension of input tensors.

required
sequence_length int

Base sequence length for filter parameter initialization. The filters will be interpolated to match actual input sequence lengths.

required
activation ActivationType

Activation function applied to filter parameters ("sigmoid", "tanh", "identity").

"sigmoid"
dropout float

Dropout probability applied after filtering.

0.0
norm_eps float

Epsilon for numerical stability.

1e-5
learnable_filters bool

Whether filter parameters are learnable (always True for this class).

True
fft_norm str

FFT normalization mode.

"ortho"
filter_init_std float

Standard deviation for filter parameter initialization.

0.02

Attributes:

Name Type Description
activation str

Activation function name.

filter_real Parameter

Real part of complex filter parameters.

filter_imag Parameter

Imaginary part of complex filter parameters.

fft1d FFT1D

1D FFT transform for sequence dimension.

activation_fn Module

Activation function module (Sigmoid, Tanh, or Identity).

Methods:

Name Description
forward

Apply global filtering to input tensor.

get_filter_response

Get the current frequency response of the filters.

get_spectral_properties

Get spectral properties of global filtering.

from_config

Create GlobalFilterMixing layer from configuration.

Source code in spectrans/layers/mixing/global_filter.py
def __init__(
    self,
    hidden_dim: int,
    sequence_length: int,
    activation: ActivationType = "sigmoid",
    dropout: float = 0.0,
    norm_eps: float = 1e-5,
    learnable_filters: bool = True,
    fft_norm: FFTNorm = "ortho",
    filter_init_std: float = 0.02,
):
    super().__init__(hidden_dim, sequence_length, dropout, norm_eps, learnable_filters)
    self.activation = activation

    # Initialize complex filter parameters
    self.filter_real = nn.Parameter(torch.randn(sequence_length, hidden_dim) * filter_init_std)
    self.filter_imag = nn.Parameter(torch.randn(sequence_length, hidden_dim) * filter_init_std)

    # Store FFT transform as non-module attribute
    self.fft1d: FFT1D  # Type annotation for mypy
    object.__setattr__(self, "fft1d", FFT1D(norm=fft_norm))

    # Activation function
    if activation == "sigmoid":
        self.activation_fn: Callable[[Tensor], Tensor] = nn.Sigmoid()
    elif activation == "tanh":
        self.activation_fn = nn.Tanh()
    elif activation == "identity":
        self.activation_fn = nn.Identity()
    else:
        raise ValueError(f"Unknown activation: {activation}")
Functions
forward
forward(x: Tensor) -> Tensor

Apply global filtering to input tensor.

Parameters:

Name Type Description Default
x Tensor

Input tensor of shape (batch_size, sequence_length, hidden_dim).

required

Returns:

Type Description
Tensor

Filtered tensor of same shape as input.

Source code in spectrans/layers/mixing/global_filter.py
def forward(self, x: Tensor) -> Tensor:
    """Apply global filtering to input tensor.

    Parameters
    ----------
    x : Tensor
        Input tensor of shape (batch_size, sequence_length, hidden_dim).

    Returns
    -------
    Tensor
        Filtered tensor of same shape as input.
    """
    # Transform to frequency domain
    x_freq = self.fft1d.transform(x, dim=1)  # Along sequence dimension

    # Get actual sequence length
    seq_len = x_freq.shape[1]

    # Adapt filter to actual sequence length using interpolation
    if seq_len != self.sequence_length:
        # Use interpolation to adapt filters to the actual sequence length
        # This preserves the learned frequency patterns at different resolutions
        filter_real = (
            nn.functional.interpolate(
                self.filter_real.T.unsqueeze(0),  # (1, hidden_dim, sequence_length)
                size=seq_len,
                mode="linear",
                align_corners=False,
            )
            .squeeze(0)
            .T
        )  # (seq_len, hidden_dim)

        filter_imag = (
            nn.functional.interpolate(
                self.filter_imag.T.unsqueeze(0),  # (1, hidden_dim, sequence_length)
                size=seq_len,
                mode="linear",
                align_corners=False,
            )
            .squeeze(0)
            .T
        )  # (seq_len, hidden_dim)
    else:
        filter_real = self.filter_real
        filter_imag = self.filter_imag

    # Create complex filter
    filter_complex = make_complex(
        self.activation_fn(filter_real), self.activation_fn(filter_imag)
    )

    # Apply filter in frequency domain (element-wise multiplication)
    filtered_freq = complex_multiply(x_freq, filter_complex)

    # Transform back to time domain
    filtered_time = self.fft1d.inverse_transform(filtered_freq, dim=1)

    # Take real part (assuming real-valued output is desired)
    output = torch.real(filtered_time)

    # Apply dropout
    output = self.dropout(output)

    return output  # type: ignore[no-any-return]
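
The core of the forward pass is a transform, an element-wise multiplication, and an inverse transform. The sketch below is a minimal stand-alone version built on torch.fft, not the library's implementation. The function global_filter is a hypothetical helper, and the two hand-built filters are chosen so the result can be checked by inspection.

```python
import torch

torch.manual_seed(0)
x = torch.randn(2, 16, 4)  # (batch, sequence_length, hidden_dim)


def global_filter(x: torch.Tensor, filt: torch.Tensor) -> torch.Tensor:
    """FFT along the sequence, multiply by a complex filter, invert, keep the real part."""
    x_freq = torch.fft.fft(x, dim=1, norm="ortho")
    filtered = x_freq * filt  # filt broadcasts over the batch dimension
    return torch.fft.ifft(filtered, dim=1, norm="ortho").real


# An all-ones filter passes every frequency unchanged, so the input is recovered.
identity = torch.ones(16, 4, dtype=torch.complex64)
print(torch.allclose(global_filter(x, identity), x, atol=1e-5))

# Keeping only the DC bin leaves the per-channel mean over the sequence.
dc_only = torch.zeros(16, 4, dtype=torch.complex64)
dc_only[0] = 1.0
out = global_filter(x, dc_only)
channel_mean = x.mean(dim=1, keepdim=True).expand_as(x)
print(torch.allclose(out, channel_mean, atol=1e-5))
```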
get_filter_response
get_filter_response() -> Tensor

Get the current frequency response of the filters.

Returns:

Type Description
Tensor

Complex-valued frequency response of shape (sequence_length, hidden_dim).

Source code in spectrans/layers/mixing/global_filter.py
def get_filter_response(self) -> Tensor:
    """Get the current frequency response of the filters.

    Returns
    -------
    Tensor
        Complex-valued frequency response of shape (sequence_length, hidden_dim).
    """
    return make_complex(
        self.activation_fn(self.filter_real), self.activation_fn(self.filter_imag)
    )
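
The choice of activation determines what the filter looks like at the start of training. The sketch below uses zero-valued parameters as a stand-in for the small random initialization. It assumes that make_complex(real, imag) behaves like torch.complex(real, imag), which is not confirmed by the source shown here.

```python
import torch

# Zero-valued parameters stand in for the small (std 0.02) random initialization.
filter_real = torch.zeros(16, 4)
filter_imag = torch.zeros(16, 4)

activations = {
    "sigmoid": torch.sigmoid,
    "tanh": torch.tanh,
    "identity": lambda t: t,
}

responses = {}
for name, fn in activations.items():
    # Assumed equivalent of make_complex(fn(real), fn(imag)).
    responses[name] = torch.complex(fn(filter_real), fn(filter_imag))
    print(name, responses[name].abs()[0, 0].item())

# sigmoid maps 0 to 0.5, so the response starts near 0.5+0.5j (magnitude about 0.707);
# tanh and identity map 0 to 0, so the response starts near zero.
```

With the default sigmoid activation the layer therefore starts close to a uniform scaling of all frequencies, while tanh and identity start close to a filter that blocks everything.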
get_spectral_properties
get_spectral_properties() -> dict[str, str | bool | int]

Get spectral properties of global filtering.

Returns:

Type Description
dict[str, str | bool | int]

Properties including filter characteristics.

Source code in spectrans/layers/mixing/global_filter.py
def get_spectral_properties(self) -> dict[str, str | bool | int]:
    """Get spectral properties of global filtering.

    Returns
    -------
    dict[str, str | bool | int]
        Properties including filter characteristics.
    """
    return {
        "frequency_domain": True,
        "learnable_filters": True,
        "complex_valued": True,
        "selective_filtering": True,
        "energy_preserving": False,  # Filtering can change energy
        "activation": self.activation,
        "parameter_count": 2 * self.sequence_length * self.hidden_dim,
    }
from_config classmethod
from_config(config: GlobalFilterMixingConfig) -> GlobalFilterMixing

Create GlobalFilterMixing layer from configuration.

Parameters:

Name Type Description Default
config GlobalFilterMixingConfig

Configuration object with layer parameters.

required

Returns:

Type Description
GlobalFilterMixing

Configured global filter mixing layer.

Source code in spectrans/layers/mixing/global_filter.py
@classmethod
def from_config(cls, config: "GlobalFilterMixingConfig") -> "GlobalFilterMixing":
    """Create GlobalFilterMixing layer from configuration.

    Parameters
    ----------
    config : GlobalFilterMixingConfig
        Configuration object with layer parameters.

    Returns
    -------
    GlobalFilterMixing
        Configured global filter mixing layer.
    """
    return cls(
        hidden_dim=config.hidden_dim,
        sequence_length=config.sequence_length,
        activation=config.activation,
        dropout=config.dropout,
        learnable_filters=config.learnable_filters,
        fft_norm=config.fft_norm,
        filter_init_std=config.filter_init_std,
    )

GlobalFilterMixing2D

GlobalFilterMixing2D(hidden_dim: int, sequence_length: int, activation: ActivationType = 'sigmoid', dropout: float = 0.0, norm_eps: float = 1e-05, learnable_filters: bool = True, fft_norm: FFTNorm = 'ortho', filter_init_std: float = 0.02)

Bases: FilterMixingLayer

2D Global Filter mixing with filtering along both dimensions.

Extends global filtering to both sequence and feature dimensions, similar to FNet's 2D FFT but with learnable filters.

Parameters:

Name Type Description Default
hidden_dim int

Hidden dimension of input tensors.

required
sequence_length int

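Base sequence length used to size the filter parameters. Filters are bilinearly interpolated when the input dimensions differ.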

required
activation ActivationType

Activation function for filter parameters.

"sigmoid"
dropout float

Dropout probability.

0.0
norm_eps float

Epsilon for numerical stability.

1e-5
learnable_filters bool

Whether filters are learnable.

True
fft_norm str

FFT normalization mode.

"ortho"
filter_init_std float

Filter parameter initialization standard deviation.

0.02

Attributes:

Name Type Description
filter_real Parameter

Real part of 2D complex filters.

filter_imag Parameter

Imaginary part of 2D complex filters.

fft2d FFT2D

2D FFT transform module.

activation_fn Module

Activation function.

Methods:

Name Description
forward

Apply 2D global filtering.

get_filter_response

Get 2D frequency response.

get_spectral_properties

Get 2D filter properties.

Source code in spectrans/layers/mixing/global_filter.py
def __init__(
    self,
    hidden_dim: int,
    sequence_length: int,
    activation: ActivationType = "sigmoid",
    dropout: float = 0.0,
    norm_eps: float = 1e-5,
    learnable_filters: bool = True,
    fft_norm: FFTNorm = "ortho",
    filter_init_std: float = 0.02,
):
    super().__init__(hidden_dim, sequence_length, dropout, norm_eps, learnable_filters)
    self.activation = activation

    # Initialize 2D complex filter parameters
    self.filter_real = nn.Parameter(torch.randn(sequence_length, hidden_dim) * filter_init_std)
    self.filter_imag = nn.Parameter(torch.randn(sequence_length, hidden_dim) * filter_init_std)

    # Store 2D FFT transform as non-module attribute
    self.fft2d: FFT2D  # Type annotation for mypy
    object.__setattr__(self, "fft2d", FFT2D(norm=fft_norm))

    # Activation function
    if activation == "sigmoid":
        self.activation_fn: Callable[[Tensor], Tensor] = nn.Sigmoid()
    elif activation == "tanh":
        self.activation_fn = nn.Tanh()
    elif activation == "identity":
        self.activation_fn = nn.Identity()
    else:
        raise ValueError(f"Unknown activation: {activation}")
Functions
forward
forward(x: Tensor) -> Tensor

Apply 2D global filtering.

Parameters:

Name Type Description Default
x Tensor

Input tensor of shape (batch_size, sequence_length, hidden_dim).

required

Returns:

Type Description
Tensor

Filtered tensor of same shape.

Source code in spectrans/layers/mixing/global_filter.py
def forward(self, x: Tensor) -> Tensor:
    """Apply 2D global filtering.

    Parameters
    ----------
    x : Tensor
        Input tensor of shape (batch_size, sequence_length, hidden_dim).

    Returns
    -------
    Tensor
        Filtered tensor of same shape.
    """
    # Transform to 2D frequency domain
    x_freq = self.fft2d.transform(x, dim=(-2, -1))

    # Get actual dimensions
    seq_len = x_freq.shape[-2]
    hidden = x_freq.shape[-1]

    # Adapt filter to actual dimensions using bilinear interpolation
    if seq_len != self.sequence_length or hidden != self.hidden_dim:
        # Reshape for 2D interpolation
        filter_real = (
            nn.functional.interpolate(
                self.filter_real.unsqueeze(0).unsqueeze(0),  # (1, 1, seq_length, hidden_dim)
                size=(seq_len, hidden),
                mode="bilinear",
                align_corners=False,
            )
            .squeeze(0)
            .squeeze(0)
        )  # (seq_len, hidden)

        filter_imag = (
            nn.functional.interpolate(
                self.filter_imag.unsqueeze(0).unsqueeze(0),  # (1, 1, seq_length, hidden_dim)
                size=(seq_len, hidden),
                mode="bilinear",
                align_corners=False,
            )
            .squeeze(0)
            .squeeze(0)
        )  # (seq_len, hidden)
    else:
        filter_real = self.filter_real
        filter_imag = self.filter_imag

    # Create complex filter
    filter_complex = make_complex(
        self.activation_fn(filter_real), self.activation_fn(filter_imag)
    )

    # Apply 2D filter
    filtered_freq = complex_multiply(x_freq, filter_complex)

    # Transform back to spatial domain
    filtered_spatial = self.fft2d.inverse_transform(filtered_freq, dim=(-2, -1))

    # Take real part
    output = torch.real(filtered_spatial)

    # Apply dropout
    output = self.dropout(output)

    return output  # type: ignore[no-any-return]
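The forward pass above can be approximated outside the library with NumPy's FFT. The following is only a sketch under stated assumptions, not the spectrans implementation: `global_filter_2d` and `sigmoid` are hypothetical helpers, `np.fft.fft2` with `norm="ortho"` stands in for `FFT2D`, and dropout and the interpolation branch are left out. It also illustrates a consequence of the default `sigmoid` activation: raw parameters at zero map to a filter of 0.5 + 0.5j, so for real input the result is half the input.

```python
import numpy as np


def sigmoid(z):
    """Logistic function, standing in for nn.Sigmoid."""
    return 1.0 / (1.0 + np.exp(-z))


def global_filter_2d(x, filter_real, filter_imag):
    """Filter x of shape (batch, seq_len, hidden) in the 2D frequency domain.

    Hypothetical NumPy stand-in for the forward pass; dropout and the
    interpolation branch are omitted.
    """
    # Transform over the sequence and hidden axes with orthonormal scaling.
    x_freq = np.fft.fft2(x, axes=(-2, -1), norm="ortho")
    # Build the complex filter from the activated real and imaginary parts.
    filt = sigmoid(filter_real) + 1j * sigmoid(filter_imag)
    # Element-wise product in frequency space, broadcast over the batch axis.
    y_freq = x_freq * filt
    # Back to the token domain; keep only the real part, as the layer does.
    return np.fft.ifft2(y_freq, axes=(-2, -1), norm="ortho").real


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 8, 4))
zeros = np.zeros((8, 4))  # raw parameters exactly at zero
y = global_filter_2d(x, zeros, zeros)
```

With the default `filter_init_std` of 0.02 the raw parameters are close to, but not exactly, zero, so the real layer would start near this half-scaling behaviour rather than exactly at it.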
get_filter_response
get_filter_response() -> Tensor

Get 2D frequency response.

Returns:

Type Description
Tensor

Complex 2D frequency response.

Source code in spectrans/layers/mixing/global_filter.py
def get_filter_response(self) -> Tensor:
    """Get 2D frequency response.

    Returns
    -------
    Tensor
        Complex 2D frequency response.
    """
    return make_complex(
        self.activation_fn(self.filter_real), self.activation_fn(self.filter_imag)
    )
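The complex response is often easier to read as magnitude and phase. The sketch below is an assumption-laden illustration: it rebuilds the response in NumPy from hypothetical raw parameters rather than calling `get_filter_response`, and it assumes the default `sigmoid` activation. Because both parts then lie in (0, 1), every phase falls strictly between 0 and π/2 and every magnitude stays below √2.

```python
import numpy as np


def sigmoid(z):
    """Logistic function, standing in for nn.Sigmoid."""
    return 1.0 / (1.0 + np.exp(-z))


rng = np.random.default_rng(0)
# Hypothetical raw parameters drawn with the default init scale (std 0.02).
filter_real = 0.02 * rng.standard_normal((16, 8))
filter_imag = 0.02 * rng.standard_normal((16, 8))

# Same construction as the method above: activate each part, then combine.
response = sigmoid(filter_real) + 1j * sigmoid(filter_imag)

magnitude = np.abs(response)   # gain applied to each 2D frequency bin
phase = np.angle(response)     # phase shift applied to each bin, in radians
```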
get_spectral_properties
get_spectral_properties() -> dict[str, str | bool | int]

Get 2D filter properties.

Returns:

Type Description
dict[str, str | bool | int]

2D filtering characteristics.

Source code in spectrans/layers/mixing/global_filter.py
def get_spectral_properties(self) -> dict[str, str | bool | int]:
    """Get 2D filter properties.

    Returns
    -------
    dict[str, str | bool | int]
        2D filtering characteristics.
    """
    return {
        "frequency_domain": True,
        "learnable_filters": True,
        "complex_valued": True,
        "selective_filtering": True,
        "energy_preserving": False,
        "two_dimensional": True,
        "activation": self.activation,
        "parameter_count": 2 * self.sequence_length * self.hidden_dim,
    }

WaveletMixing

WaveletMixing(hidden_dim: int, wavelet: WaveletType = 'db4', levels: int = 3, mixing_mode: str = 'pointwise', dropout: float = 0.0)

Bases: Module

Token mixing layer using discrete wavelet transform.

Performs mixing in wavelet domain for multi-resolution processing. Decomposes input using DWT, applies learnable mixing to coefficients, and reconstructs the output with residual connections.

Mathematical Formulation

Given input tensor \(\mathbf{X} \in \mathbb{R}^{B \times N \times D}\) where \(B\) is batch size, \(N\) is sequence length, and \(D\) is hidden dimension:

Step 1: Channel-wise Decomposition

For each channel \(d \in \{0, 1, \ldots, D-1\}\), extract the channel signal:

\[ \mathbf{x}^{(d)} = \mathbf{X}[:, :, d] \in \mathbb{R}^{B \times N} \]

Apply \(J\)-level DWT decomposition:

\[ \text{DWT}_J(\mathbf{x}^{(d)}) = \{\mathbf{c}_{A_J}^{(d)}, \{\mathbf{c}_{D_j}^{(d)}\}_{j=1}^J\} \]

Where:

  • \(\mathbf{c}_{A_J}^{(d)} \in \mathbb{R}^{B \times L_{A_J}}\) are approximation coefficients at level \(J\)
  • \(\mathbf{c}_{D_j}^{(d)} \in \mathbb{R}^{B \times L_{D_j}}\) are detail coefficients at level \(j\)
  • \(L_{A_J}\) and \(L_{D_j}\) are coefficient lengths after subsampling
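To make the coefficient structure concrete, here is a minimal pure-Python sketch of a \(J\)-level decomposition using the Haar wavelet (`'db1'`). It is not the library's `DWT1D`: `haar_step` and `haar_decompose` are hypothetical helpers, the input length is assumed to be divisible by \(2^J\), and the detail list is ordered finest level first, which may differ from the ordering `DWT1D` uses.

```python
import math


def haar_step(signal):
    """One Haar analysis step on an even-length list: (approximation, detail)."""
    s = 1.0 / math.sqrt(2.0)
    approx = [(signal[i] + signal[i + 1]) * s for i in range(0, len(signal) - 1, 2)]
    detail = [(signal[i] - signal[i + 1]) * s for i in range(0, len(signal) - 1, 2)]
    return approx, detail


def haar_decompose(signal, levels):
    """J-level decomposition: returns (c_A_J, [c_D_1, ..., c_D_J])."""
    approx, details = list(signal), []
    for _ in range(levels):
        # Each level halves the approximation and emits one detail band.
        approx, detail = haar_step(approx)
        details.append(detail)
    return approx, details


x = [4.0, 2.0, 6.0, 8.0, 1.0, 3.0, 5.0, 7.0]
c_a, c_d = haar_decompose(x, levels=3)
lengths = [len(c_a)] + [len(d) for d in c_d]  # L_{A_J}, L_{D_1}, ..., L_{D_J}
```

For longer wavelets such as `'db4'` with symmetric boundary extension, the coefficient lengths are somewhat larger than \(N / 2^j\), which is why the layer handles lengths dynamically.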

Step 2: Learnable Mixing

Apply mixing transformations based on mode:

Pointwise Mixing (mixing_mode='pointwise'):

\[ \tilde{\mathbf{c}}_{A_J}^{(d)} = \mathbf{c}_{A_J}^{(d)} \odot \mathbf{W}_{A}[:, :L_{A_J}, d] \]
\[ \tilde{\mathbf{c}}_{D_j}^{(d)} = \mathbf{c}_{D_j}^{(d)} \odot \mathbf{W}_{D_j}[:, :L_{D_j}, d] \]

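Where \(\mathbf{W}_{A}, \mathbf{W}_{D_j}\) are learnable parameters and \(\odot\) denotes element-wise multiplication with broadcasting. The formulation allows shape \(1 \times \max(L) \times D\); the constructor in this module allocates \(1 \times 1 \times D\), so the slice reduces to one scale per channel and level that broadcasts over all coefficient positions.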
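The slicing and broadcasting in the pointwise formulas can be checked with a small NumPy sketch. NumPy is used here as a stand-in for PyTorch, on the assumption that basic slicing behaves the same way in both; the array names are illustrative only. The weight has the `(1, 1, hidden_dim)` shape that the constructor allocates for pointwise mode.

```python
import numpy as np

hidden_dim, batch, coeff_len = 4, 2, 5
h = 1  # channel index being processed

# Same shape as the pointwise parameters created in the constructor.
weight = np.ones((1, 1, hidden_dim))
weight[0, 0, h] = 0.5

# Hypothetical coefficients for one channel: (batch, coeff_len).
coeffs = np.arange(batch * coeff_len, dtype=float).reshape(batch, coeff_len)

# A slice that runs past the end of a size-1 axis is clamped,
# so the result has shape (1, 1) rather than (1, coeff_len).
w_slice = weight[:, :coeff_len, h]

# Broadcasting applies the single scale to every coefficient position.
mixed = coeffs * w_slice
```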

Channel Mixing (mixing_mode='channel'):

\[ \tilde{\mathbf{c}}_{A_J}^{(d)} = \mathbf{c}_{A_J}^{(d)} \cdot \mathbf{W}_{A}[0, d, d] \]
\[ \tilde{\mathbf{c}}_{D_j}^{(d)} = \mathbf{c}_{D_j}^{(d)} \cdot \mathbf{W}_{D_j}[0, d, d] \]

Where \(\mathbf{W}_{A}, \mathbf{W}_{D_j} \in \mathbb{R}^{1 \times D \times D}\) are initialized as identity matrices.
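A short NumPy sketch shows what the index pattern \(\mathbf{W}[0, d, d]\) implies: only the diagonal entry for the current channel is read, so off-diagonal entries do not affect the result. The arrays are illustrative stand-ins, not the layer's parameters.

```python
import numpy as np

hidden_dim = 3
# Identity initialization with a leading singleton axis: shape (1, D, D).
w = np.eye(hidden_dim)[None, :, :]
w[0, 0, 1] = 9.0  # perturb an off-diagonal entry

d = 0  # channel index being processed
coeffs = np.array([[1.0, 2.0, 3.0]])  # (batch, length) for channel d

# Only W[0, d, d] is read, so the off-diagonal 9.0 has no effect.
mixed = coeffs * w[:, d, d]
```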

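Level Mixing (mixing_mode='level'):

Coefficients from all levels are zero-padded to a common length and stacked, then passed through a single-head self-attention module. In the implementation, each level forms its own batch entry, so attention acts across coefficient positions within a level: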

\[ \{\tilde{\mathbf{c}}_{A_J}^{(d)}, \{\tilde{\mathbf{c}}_{D_j}^{(d)}\}_{j=1}^J\} = \text{MultiHeadAttn}(\text{Concat}(\mathbf{c}_{A_J}^{(d)}, \{\mathbf{c}_{D_j}^{(d)}\})) \]

Step 3: Reconstruction

Reconstruct the signal using inverse DWT:

\[ \tilde{\mathbf{x}}^{(d)} = \text{IDWT}_J(\{\tilde{\mathbf{c}}_{A_J}^{(d)}, \{\tilde{\mathbf{c}}_{D_j}^{(d)}\}_{j=1}^J\}) \]

Apply length adjustment if necessary:

\[ \hat{\mathbf{x}}^{(d)} = \begin{cases} \tilde{\mathbf{x}}^{(d)}[:, :N] & \text{if } |\tilde{\mathbf{x}}^{(d)}| > N \\ \text{Pad}(\tilde{\mathbf{x}}^{(d)}, N) & \text{if } |\tilde{\mathbf{x}}^{(d)}| < N \\ \tilde{\mathbf{x}}^{(d)} & \text{otherwise} \end{cases} \]
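The length adjustment amounts to cropping or zero-padding on the right along the last axis. A minimal NumPy sketch, where `fit_length` is a hypothetical helper and `np.pad` stands in for `F.pad`:

```python
import numpy as np


def fit_length(signal, target_len):
    """Crop or right-pad with zeros along the last axis to target_len."""
    current = signal.shape[-1]
    if current > target_len:
        # Reconstruction overshoots: keep the first target_len samples.
        return signal[..., :target_len]
    if current < target_len:
        # Reconstruction falls short: append zeros on the right only.
        pad = [(0, 0)] * (signal.ndim - 1) + [(0, target_len - current)]
        return np.pad(signal, pad, mode="constant")
    return signal


cropped = fit_length(np.ones((2, 10)), 8)
padded = fit_length(np.ones((2, 6)), 8)
```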

Step 4: Residual Connection and Dropout

Combine all channels and apply residual connection:

\[ \hat{\mathbf{X}} = \text{Concat}(\{\hat{\mathbf{x}}^{(d)}\}_{d=0}^{D-1}) \in \mathbb{R}^{B \times N \times D} \]
\[ \mathbf{Y} = \mathbf{X} + \text{Dropout}(\hat{\mathbf{X}}) \]
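Steps 1 to 4 can be sketched end to end in NumPy. This is an illustration under simplifying assumptions, not the library code: it uses the Haar wavelet, one scalar weight per channel and level, a sequence length divisible by \(2^J\), and no dropout. All function names are hypothetical. With every weight equal to one the mixed branch reproduces the input, so the residual sum gives \(\mathbf{Y} = 2\mathbf{X}\).

```python
import numpy as np

S = 1.0 / np.sqrt(2.0)


def haar_dwt(x, levels):
    """Multi-level Haar DWT along the last axis; details are finest first."""
    approx, details = x, []
    for _ in range(levels):
        even, odd = approx[..., 0::2], approx[..., 1::2]
        approx, detail = (even + odd) * S, (even - odd) * S
        details.append(detail)
    return approx, details


def haar_idwt(approx, details):
    """Inverse of haar_dwt: rebuild from the coarsest level upwards."""
    for detail in reversed(details):
        out = np.empty(approx.shape[:-1] + (2 * approx.shape[-1],))
        out[..., 0::2] = (approx + detail) * S
        out[..., 1::2] = (approx - detail) * S
        approx = out
    return approx


def wavelet_mix(x, w_approx, w_details):
    """x: (B, N, D); w_approx: (D,); w_details: (J, D). Returns X + X_hat."""
    channels = []
    for d in range(x.shape[-1]):
        # Step 1: decompose one channel.
        approx, details = haar_dwt(x[:, :, d], levels=len(w_details))
        # Step 2: scale each coefficient band.
        approx = approx * w_approx[d]
        details = [c * w_details[j][d] for j, c in enumerate(details)]
        # Step 3: reconstruct and restore the channel axis.
        channels.append(haar_idwt(approx, details)[..., None])
    # Step 4: concatenate channels and add the residual.
    return x + np.concatenate(channels, axis=-1)


x = np.random.default_rng(0).standard_normal((2, 16, 3))
y = wavelet_mix(x, np.ones(3), np.ones((2, 3)))
```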
Complexity Analysis
  • Time Complexity: \(O(DNJ)\) for the transforms, plus the mode-dependent mixing cost, per forward pass

    • \(O(N)\) for DWT/IDWT per level and channel (linear in signal length)
    • \(O(DJ)\) for mixing operations across all levels and channels
    • Dominated by DWT operations when \(J\) is small
  • Space Complexity: \(O(DN + P)\) where \(P\) is parameter count

    • \(O(DN)\) for storing coefficient tensors
    • Parameter count depends on mixing mode:
      • Pointwise: \(P = O(LD)\) where \(L\) is max coefficient length
      • Channel: \(P = O(JD^2)\)
      • Level: \(P = O(D^2)\) for attention parameters
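For a concrete count, the tensor shapes allocated in this class's constructor can be tallied directly. The helper below is a hypothetical sketch that follows those allocated shapes rather than the asymptotic expressions. The figure for level mode assumes PyTorch's `nn.MultiheadAttention` with `embed_dim=1` and its default bias settings.

```python
def wavelet_mixing_param_count(hidden_dim, levels, mixing_mode):
    """Count learnable parameters implied by the constructor's tensor shapes."""
    if mixing_mode == "pointwise":
        # One (1, 1, D) tensor for the approximation and one per detail level.
        return (levels + 1) * hidden_dim
    if mixing_mode == "channel":
        # One (1, D, D) matrix for the approximation and one per detail level.
        return (levels + 1) * hidden_dim * hidden_dim
    if mixing_mode == "level":
        # Assumed layout for embed_dim=1 with default biases:
        # in_proj weight (3) + in_proj bias (3) + out_proj weight (1) + bias (1).
        return 8
    raise ValueError(f"Unknown mixing mode: {mixing_mode}")


pointwise_params = wavelet_mixing_param_count(256, 3, "pointwise")
channel_params = wavelet_mixing_param_count(64, 2, "channel")
level_params = wavelet_mixing_param_count(128, 4, "level")
```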
Implementation Notes
  • Uses PyTorch-native DWT implementation for gradient compatibility
  • Dynamic weight slicing ensures proper alignment with variable-length coefficients
  • Perfect reconstruction property maintained through careful length handling
  • Each channel processed independently for computational efficiency

Parameters:

Name Type Description Default
hidden_dim int

Hidden dimension size \(D\).

required
wavelet str

Wavelet type (e.g., 'db1', 'db4', 'sym2'). Determines filter bank characteristics.

'db4'
levels int

Number of decomposition levels \(J\). Controls resolution hierarchy.

3
mixing_mode str

Mixing strategy: 'pointwise' (element-wise), 'channel' (diagonal), 'level' (attention).

'pointwise'
dropout float

Dropout probability applied to mixed coefficients before residual connection.

0.0

Attributes:

Name Type Description
dwt DWT1D

Wavelet transform module implementing PyTorch-native DWT/IDWT.

mixing_weights ParameterDict

Learnable parameters for coefficient mixing; their structure depends on mixing_mode.

dropout Dropout

Dropout layer for regularization.

Raises:

Type Description
ValueError

If mixing_mode is not one of {'pointwise', 'channel', 'level'}.

Examples:

Basic usage with pointwise mixing:

>>> mixer = WaveletMixing(hidden_dim=256, wavelet='db4', levels=3)
>>> x = torch.randn(32, 128, 256)  # (batch, seq_len, hidden)
>>> output = mixer(x)
>>> assert output.shape == x.shape

Channel mixing with identity initialization:

>>> mixer = WaveletMixing(hidden_dim=64, mixing_mode='channel', levels=2)
>>> x = torch.randn(16, 64, 64)
>>> output = mixer(x)
>>> # At initialization the mixing itself is an identity map, so the output
>>> # is the reconstructed input plus the residual (approximately 2 * x)

Cross-level mixing with attention:

>>> mixer = WaveletMixing(hidden_dim=128, mixing_mode='level', levels=4)
>>> x = torch.randn(8, 256, 128)
>>> output = mixer(x)  # Self-attention applied to the stacked wavelet coefficients

Methods:

Name Description
forward

Apply wavelet-based mixing following the mathematical formulation.

from_config

Create WaveletMixing from configuration.

Source code in spectrans/layers/mixing/wavelet.py
def __init__(
    self,
    hidden_dim: int,
    wavelet: WaveletType = "db4",
    levels: int = 3,
    mixing_mode: str = "pointwise",
    dropout: float = 0.0,
):
    super().__init__()

    self.hidden_dim = hidden_dim
    self.wavelet = wavelet
    self.levels = levels
    self.mixing_mode = mixing_mode

    # Initialize wavelet transform
    self.dwt = DWT1D(wavelet=wavelet, levels=levels, mode="symmetric")

    # Initialize mixing weights based on mode
    self.mixing_weights = nn.ParameterDict()

    if mixing_mode == "pointwise":
        # Simple pointwise multiplication for each level
        self.mixing_weights["approx"] = nn.Parameter(torch.ones(1, 1, hidden_dim))
        for level in range(levels):
            self.mixing_weights[f"detail_{level}"] = nn.Parameter(torch.ones(1, 1, hidden_dim))

    elif mixing_mode == "channel":
        # Channel-wise mixing matrices
        self.mixing_weights["approx"] = nn.Parameter(torch.eye(hidden_dim).unsqueeze(0))
        for level in range(levels):
            self.mixing_weights[f"detail_{level}"] = nn.Parameter(
                torch.eye(hidden_dim).unsqueeze(0)
            )

    elif mixing_mode == "level":
        # Cross-level mixing with attention-like mechanism
        # Use 1 as embedding dim since we process each channel independently
        self.level_mixer = nn.MultiheadAttention(
            1,
            num_heads=1,  # Feature dim=1, so only 1 head possible
            dropout=dropout,
            batch_first=True,
        )
    else:
        raise ValueError(f"Unknown mixing mode: {mixing_mode}")

    self.dropout = nn.Dropout(dropout)
Functions
forward
forward(x: Tensor) -> Tensor

Apply wavelet-based mixing following the mathematical formulation.

Implements the complete wavelet mixing pipeline: decomposition → mixing → reconstruction → residual. Each hidden dimension is processed independently to maintain channel separability.

Mathematical Implementation

The forward pass implements the mathematical formulation exactly:

  1. Channel Extraction: \(\mathbf{x}^{(d)} = \mathbf{X}[:, :, d]\) for \(d = 0, \ldots, D-1\)
  2. Wavelet Decomposition: \(\text{DWT}_J(\mathbf{x}^{(d)}) \rightarrow \{\mathbf{c}_{A_J}^{(d)}, \{\mathbf{c}_{D_j}^{(d)}\}\}\)
  3. Learnable Mixing: Apply mode-specific transformations to coefficients
  4. Signal Reconstruction: \(\text{IDWT}_J(\text{mixed coefficients}) \rightarrow \hat{\mathbf{x}}^{(d)}\)
  5. Channel Concatenation: \(\hat{\mathbf{X}} = [\hat{\mathbf{x}}^{(0)}, \ldots, \hat{\mathbf{x}}^{(D-1)}]\)
  6. Residual Connection: \(\mathbf{Y} = \mathbf{X} + \text{Dropout}(\hat{\mathbf{X}})\)

Parameters:

Name Type Description Default
x Tensor

Input tensor of shape \((B, N, D)\) where:

  • \(B\) is batch size
  • \(N\) is sequence length
  • \(D\) is hidden dimension
required

Returns:

Type Description
Tensor

Mixed output tensor of identical shape \((B, N, D)\) with wavelet-domain mixing applied and residual connection.

Notes
  • Dynamic coefficient length handling ensures robustness to varying sequence lengths
  • Perfect reconstruction property maintained through careful padding/truncation
  • Gradient flow preserved through PyTorch-native operations
Source code in spectrans/layers/mixing/wavelet.py
def forward(self, x: Tensor) -> Tensor:
    r"""Apply wavelet-based mixing following the mathematical formulation.

    Implements the complete wavelet mixing pipeline: decomposition → mixing → reconstruction → residual.
    Each hidden dimension is processed independently to maintain channel separability.

    Mathematical Implementation
    ---------------------------
    The forward pass implements the mathematical formulation exactly:

    1. **Channel Extraction**: $\mathbf{x}^{(d)} = \mathbf{X}[:, :, d]$ for $d = 0, \ldots, D-1$
    2. **Wavelet Decomposition**: $\text{DWT}_J(\mathbf{x}^{(d)}) \rightarrow \{\mathbf{c}_{A_J}^{(d)}, \{\mathbf{c}_{D_j}^{(d)}\}\}$
    3. **Learnable Mixing**: Apply mode-specific transformations to coefficients
    4. **Signal Reconstruction**: $\text{IDWT}_J(\text{mixed coefficients}) \rightarrow \hat{\mathbf{x}}^{(d)}$
    5. **Channel Concatenation**: $\hat{\mathbf{X}} = [\hat{\mathbf{x}}^{(0)}, \ldots, \hat{\mathbf{x}}^{(D-1)}]$
    6. **Residual Connection**: $\mathbf{Y} = \mathbf{X} + \text{Dropout}(\hat{\mathbf{X}})$

    Parameters
    ----------
    x : Tensor
        Input tensor of shape $(B, N, D)$ where:

        - $B$ is batch size
        - $N$ is sequence length
        - $D$ is hidden dimension

    Returns
    -------
    Tensor
        Mixed output tensor of identical shape $(B, N, D)$ with wavelet-domain
        mixing applied and residual connection.

    Notes
    -----
    - Dynamic coefficient length handling ensures robustness to varying sequence lengths
    - Perfect reconstruction property maintained through careful padding/truncation
    - Gradient flow preserved through PyTorch-native operations
    """
    _, seq_len, hidden_dim = x.shape

    # Store original input for residual connection
    residual = x

    # Process each hidden dimension independently
    outputs = []
    for h in range(hidden_dim):
        # Extract single channel and squeeze to 2D for DWT
        x_channel = x[:, :, h]  # Shape: [batch, seq_len]

        # Decompose using DWT
        approx, details = self.dwt.decompose(x_channel, dim=-1)

        # Apply mixing based on mode
        if self.mixing_mode == "pointwise":
            # Apply pointwise scaling - need to handle the shape correctly
            # approx shape is [batch, approx_len], weight needs to match
            approx_len = approx.shape[-1]
            approx_weight = self.mixing_weights["approx"][:, :approx_len, h]
            approx_mixed = approx * approx_weight

            details_mixed = []
            for level, detail in enumerate(details):
                detail_len = detail.shape[-1]
                weight = self.mixing_weights[f"detail_{level}"][:, :detail_len, h]
                details_mixed.append(detail * weight)

        elif self.mixing_mode == "channel":
            # Apply channel mixing (simplified for single channel processing)
            approx_mixed = approx * self.mixing_weights["approx"][:, h, h]
            details_mixed = []
            for level, detail in enumerate(details):
                weight = self.mixing_weights[f"detail_{level}"][:, h, h]
                details_mixed.append(detail * weight)

        elif self.mixing_mode == "level":
            # Stack all coefficients for cross-level mixing
            all_coeffs = [approx, *details]
            max_len = max(c.shape[-1] for c in all_coeffs)  # Use -1 for last dimension

            # Pad to same length
            padded_coeffs = []
            for coeff in all_coeffs:
                if coeff.shape[-1] < max_len:  # Use -1 to work with last dimension
                    pad_len = max_len - coeff.shape[-1]
                    coeff = F.pad(coeff, (0, pad_len))  # Pad the last dimension
                padded_coeffs.append(coeff)

            # Stack and apply attention
            stacked = torch.stack(padded_coeffs, dim=1)  # (batch, levels+1, max_len)

            # Reshape for attention: (batch * (levels+1), max_len) -> (batch * (levels+1), max_len, 1) for attention
            batch_size_coeff = stacked.shape[0]
            num_levels = stacked.shape[1]
            seq_len_coeff = stacked.shape[2]

            # Flatten batch and levels, then add feature dimension
            stacked_flat = stacked.view(
                batch_size_coeff * num_levels, seq_len_coeff, 1
            )  # (batch * levels, seq_len, 1)

            # Apply self-attention across sequence positions for each level independently
            mixed_flat, _ = self.level_mixer(stacked_flat, stacked_flat, stacked_flat)

            # Reshape back to separate batch and levels
            mixed = mixed_flat.view(
                batch_size_coeff, num_levels, seq_len_coeff, 1
            )  # Feature dim is 1, not hidden_dim

            # Extract mixed coefficients
            approx_mixed = mixed[
                :, 0, : approx.shape[-1], 0
            ]  # Extract approx coeffs for current channel
            details_mixed = []
            for level in range(self.levels):
                detail_len = details[level].shape[-1]
                detail_mixed = mixed[
                    :, level + 1, :detail_len, 0
                ]  # Extract detail coeffs for current channel
                details_mixed.append(detail_mixed)

        # Reconstruct signal
        reconstructed = self.dwt.reconstruct((approx_mixed, details_mixed), dim=-1)

        # Ensure output has correct length
        if reconstructed.shape[-1] != seq_len:
            if reconstructed.shape[-1] > seq_len:
                reconstructed = reconstructed[:, :seq_len]
            else:
                # Pad if needed
                pad_len = seq_len - reconstructed.shape[-1]
                reconstructed = F.pad(reconstructed, (0, pad_len))

        outputs.append(reconstructed.unsqueeze(-1))  # Add channel dim back

    # Combine all channels
    output = torch.cat(outputs, dim=-1)

    # Apply dropout and residual connection
    output = self.dropout(output)
    output = output + residual

    result: Tensor = output
    return result
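The pad-and-stack step of the `level` branch is easy to misread, so here is a NumPy sketch of the shapes involved. The coefficient arrays are hypothetical, and `np.pad`, `np.stack` and `reshape` stand in for `F.pad`, `torch.stack` and `view`. After the reshape, each row of the leading axis holds one level of one sample, so a `batch_first` attention module attends over the `max_len` positions within that level. Since the listing passes no padding mask, the zero-padded positions take part in the attention as well.

```python
import numpy as np

batch = 2
# Hypothetical coefficient arrays: approximation plus two detail levels.
coeffs = [np.ones((batch, 3)), np.ones((batch, 3)), np.ones((batch, 5))]

# Right-pad every band with zeros to the longest band.
max_len = max(c.shape[-1] for c in coeffs)
padded = [
    np.pad(c, ((0, 0), (0, max_len - c.shape[-1])), mode="constant")
    for c in coeffs
]

# (batch, levels + 1, max_len)
stacked = np.stack(padded, axis=1)

# Merge batch and level axes and add a feature axis of size 1:
# (batch * (levels + 1), max_len, 1)
tokens = stacked.reshape(batch * len(coeffs), max_len, 1)
```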
from_config classmethod
from_config(config: WaveletMixingConfig) -> WaveletMixing

Create WaveletMixing from configuration.

Parameters:

Name Type Description Default
config WaveletMixingConfig

Typed and validated configuration.

required

Returns:

Type Description
WaveletMixing

Configured instance.

Source code in spectrans/layers/mixing/wavelet.py
@classmethod
def from_config(cls, config: "WaveletMixingConfig") -> "WaveletMixing":
    """Create WaveletMixing from configuration.

    Parameters
    ----------
    config : WaveletMixingConfig
        Typed and validated configuration.

    Returns
    -------
    WaveletMixing
        Configured instance.
    """
    return cls(
        hidden_dim=config.hidden_dim,
        wavelet=config.wavelet,
        levels=config.levels,
        mixing_mode=config.mixing_mode,
        dropout=config.dropout,
    )
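The `from_config` pattern can be tried without installing the library by using stand-ins. In the sketch below `DemoWaveletConfig` and `DemoMixer` are hypothetical classes: the config exposes only the five fields that `from_config` reads above, and the mixer merely records its arguments. The real `WaveletMixingConfig` may carry validation and further fields.

```python
from dataclasses import dataclass


@dataclass
class DemoWaveletConfig:
    """Hypothetical stand-in exposing the fields that from_config reads."""

    hidden_dim: int
    wavelet: str = "db4"
    levels: int = 3
    mixing_mode: str = "pointwise"
    dropout: float = 0.0


class DemoMixer:
    """Hypothetical stand-in for WaveletMixing that only records its arguments."""

    def __init__(self, hidden_dim, wavelet="db4", levels=3,
                 mixing_mode="pointwise", dropout=0.0):
        self.kwargs = dict(hidden_dim=hidden_dim, wavelet=wavelet, levels=levels,
                           mixing_mode=mixing_mode, dropout=dropout)

    @classmethod
    def from_config(cls, config):
        # Forward each config field to the matching constructor argument.
        return cls(
            hidden_dim=config.hidden_dim,
            wavelet=config.wavelet,
            levels=config.levels,
            mixing_mode=config.mixing_mode,
            dropout=config.dropout,
        )


mixer = DemoMixer.from_config(DemoWaveletConfig(hidden_dim=64, levels=2))
```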

WaveletMixing2D

WaveletMixing2D(channels: int, wavelet: WaveletType = 'db4', levels: int = 2, mixing_mode: str = 'subband')

Bases: Module

2D wavelet mixing layer for image-like data.

Performs mixing in 2D wavelet domain, suitable for vision transformers and other architectures processing 2D spatial data. Processes spatial information through multi-resolution wavelet subbands.

Mathematical Formulation

Given input tensor \(\mathbf{X} \in \mathbb{R}^{B \times C \times H \times W}\) where \(B\) is batch size, \(C\) is channels, \(H\) is height, and \(W\) is width:

Step 1: Channel-wise 2D Decomposition

For each channel \(c \in \{0, 1, \ldots, C-1\}\), extract spatial data:

\[ \mathbf{X}^{(c)} = \mathbf{X}[:, c, :, :] \in \mathbb{R}^{B \times H \times W} \]

Apply \(J\)-level 2D DWT decomposition:

\[ \text{DWT2D}_J(\mathbf{X}^{(c)}) = \{\mathbf{LL}_J^{(c)}, \{(\mathbf{LH}_j^{(c)}, \mathbf{HL}_j^{(c)}, \mathbf{HH}_j^{(c)})\}_{j=1}^J\} \]

Where:

  • \(\mathbf{LL}_J^{(c)} \in \mathbb{R}^{B \times H_J \times W_J}\) is the approximation subband (low-low)
  • \(\mathbf{LH}_j^{(c)}, \mathbf{HL}_j^{(c)}, \mathbf{HH}_j^{(c)} \in \mathbb{R}^{B \times H_j \times W_j}\) are detail subbands
  • \(H_j = \frac{H}{2^j}\), \(W_j = \frac{W}{2^j}\) are spatial dimensions at level \(j\)
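A single-level 2D Haar transform makes the four subbands tangible. This NumPy sketch is not the library's `DWT2D`: `haar_dwt2d` and `haar_idwt2d` are hypothetical helpers, height and width are assumed even, and which detail band is called LH and which HL is a convention that may differ between libraries.

```python
import numpy as np


def haar_dwt2d(x):
    """Single-level 2D Haar transform over the last two axes (even H and W)."""
    # The four samples of every 2x2 block.
    a = x[..., 0::2, 0::2]
    b = x[..., 0::2, 1::2]
    c = x[..., 1::2, 0::2]
    d = x[..., 1::2, 1::2]
    ll = (a + b + c + d) / 2.0  # block average: approximation
    lh = (a + b - c - d) / 2.0  # difference between rows
    hl = (a - b + c - d) / 2.0  # difference between columns
    hh = (a - b - c + d) / 2.0  # diagonal difference
    return ll, (lh, hl, hh)


def haar_idwt2d(ll, details):
    """Inverse of haar_dwt2d."""
    lh, hl, hh = details
    out = np.empty(ll.shape[:-2] + (2 * ll.shape[-2], 2 * ll.shape[-1]))
    out[..., 0::2, 0::2] = (ll + lh + hl + hh) / 2.0
    out[..., 0::2, 1::2] = (ll + lh - hl - hh) / 2.0
    out[..., 1::2, 0::2] = (ll - lh + hl - hh) / 2.0
    out[..., 1::2, 1::2] = (ll - lh - hl + hh) / 2.0
    return out


x = np.random.default_rng(0).standard_normal((2, 1, 8, 8))
ll, details = haar_dwt2d(x)
restored = haar_idwt2d(ll, details)
```

Applying `haar_dwt2d` again to `ll` yields the next level, which is how the \(J\)-level hierarchy is built.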

Step 2: Subband Mixing

Apply mixing transformations based on mode:

Subband Mixing (mixing_mode='subband'):

Independent processing of each subband using convolutional networks:

\[ \tilde{\mathbf{LL}}_J^{(c)} = f_{LL}(\mathbf{LL}_J^{(c)}) \]
\[ \tilde{\mathbf{LH}}_j^{(c)} = f_{LH}^{(j)}(\mathbf{LH}_j^{(c)}), \quad \tilde{\mathbf{HL}}_j^{(c)} = f_{HL}^{(j)}(\mathbf{HL}_j^{(c)}), \quad \tilde{\mathbf{HH}}_j^{(c)} = f_{HH}^{(j)}(\mathbf{HH}_j^{(c)}) \]

Where \(f_{\cdot}\) are learnable convolutional transformations.

Cross Mixing (mixing_mode='cross'):

Cross-attention across all subbands:

\[ \{\tilde{\mathbf{LL}}_J^{(c)}, \{\tilde{\mathbf{LH}}_j^{(c)}, \tilde{\mathbf{HL}}_j^{(c)}, \tilde{\mathbf{HH}}_j^{(c)}\}\} = \text{CrossAttn}(\text{AllSubbands}^{(c)}) \]
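The bookkeeping around the attention call, flattening each subband into tokens, concatenating, and later splitting by running offsets, can be sketched in NumPy. `pack_subbands` and `unpack_subbands` are hypothetical helpers and the attention step itself is left out, so the round trip should return the original subbands unchanged.

```python
import numpy as np


def pack_subbands(subbands):
    """Flatten (B, 1, H, W) subbands to (B, H*W, 1) tokens and concatenate."""
    tokens = [
        s.reshape(s.shape[0], s.shape[1], -1).transpose(0, 2, 1)
        for s in subbands
    ]
    return np.concatenate(tokens, axis=1)


def unpack_subbands(tokens, shapes):
    """Split the token axis back into subbands of the given (B, 1, H, W) shapes."""
    out, offset = [], 0
    for shape in shapes:
        size = shape[2] * shape[3]
        # Take this subband's tokens and restore the (B, 1, H*W) layout.
        chunk = tokens[:, offset:offset + size, :].transpose(0, 2, 1)
        out.append(chunk.reshape(shape))
        offset += size
    return out


rng = np.random.default_rng(0)
# One decomposition level: LL, LH, HL, HH, each 4x4 for a single channel.
shapes = [(2, 1, 4, 4)] * 4
subbands = [rng.standard_normal(s) for s in shapes]

tokens = pack_subbands(subbands)
recovered = unpack_subbands(tokens, shapes)
```

The token axis grows with the total number of coefficients, so attention over it scales quadratically in that count.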

Step 3: 2D Reconstruction

Reconstruct the spatial signal:

\[ \tilde{\mathbf{X}}^{(c)} = \text{IDWT2D}_J(\{\tilde{\mathbf{LL}}_J^{(c)}, \{\tilde{\mathbf{LH}}_j^{(c)}, \tilde{\mathbf{HL}}_j^{(c)}, \tilde{\mathbf{HH}}_j^{(c)}\}\}) \]

Step 4: Channel Concatenation and Residual

\[ \hat{\mathbf{X}} = \text{Stack}(\{\tilde{\mathbf{X}}^{(c)}\}_{c=0}^{C-1}) \in \mathbb{R}^{B \times C \times H \times W} \]
\[ \mathbf{Y} = \mathbf{X} + \hat{\mathbf{X}} \]
Complexity Analysis
  • Time Complexity: \(O(CHW \cdot J) + O(\text{mixing operations})\)
  • Space Complexity: \(O(CHW + \text{subband storage})\)

Where mixing complexity depends on mode:

  • Subband: \(O(\text{conv operations per subband})\)
  • Cross: \(O(\text{attention across subbands})\)
  • Attention: \(O(\text{transformer encoder})\)

Parameters:

Name Type Description Default
channels int

Number of input/output channels \(C\).

required
wavelet str

Wavelet type determining 2D filter bank characteristics.

'db4'
levels int

Number of decomposition levels \(J\).

2
mixing_mode str

Subband mixing strategy: 'subband' (independent), 'cross' (attention), 'attention' (transformer).

'subband'

Attributes:

Name Type Description
dwt DWT2D

2D wavelet transform module.

ll_mixer Sequential

Convolutional network for LL subband (subband mode).

detail_mixers ModuleList

Convolutional networks for detail subbands (subband mode).

cross_mixer MultiheadAttention

Cross-attention module (cross mode).

subband_attention TransformerEncoder

Transformer encoder for subband attention (attention mode).

Raises:

Type Description
ValueError

If mixing_mode is not one of {'subband', 'cross', 'attention'}.

Examples:

Independent subband processing:

>>> mixer = WaveletMixing2D(channels=256, wavelet='db4', levels=2)
>>> x = torch.randn(32, 256, 64, 64)  # (batch, channels, height, width)
>>> output = mixer(x)
>>> assert output.shape == x.shape

Cross-subband attention:

>>> mixer = WaveletMixing2D(channels=128, mixing_mode='cross', levels=3)
>>> x = torch.randn(16, 128, 128, 128)
>>> output = mixer(x)  # Attention applied across all wavelet subbands

Methods:

Name Description
forward

Apply 2D wavelet-based mixing following the mathematical formulation.

from_config

Create WaveletMixing2D from configuration.

Source code in spectrans/layers/mixing/wavelet.py
def __init__(
    self,
    channels: int,
    wavelet: WaveletType = "db4",
    levels: int = 2,
    mixing_mode: str = "subband",
):
    super().__init__()

    self.channels = channels
    self.wavelet = wavelet
    self.levels = levels
    self.mixing_mode = mixing_mode

    # Initialize 2D wavelet transform
    self.dwt = DWT2D(wavelet=wavelet, levels=levels, mode="symmetric")

    # Initialize mixing layers based on mode
    if mixing_mode == "subband":
        # Independent processing of each subband
        # Each subband from DWT has 1 channel, so conv layers should expect 1 channel input
        self.ll_mixer = nn.Sequential(
            nn.Conv2d(1, 1, 3, padding=1),  # 1 channel in/out for single subband
            nn.BatchNorm2d(1),
            nn.ReLU(inplace=True),
        )

        self.detail_mixers = nn.ModuleList()
        for _ in range(levels):
            detail_mixer = nn.ModuleDict(
                {
                    "lh": nn.Conv2d(1, 1, 3, padding=1),  # 1 channel in/out per detail subband
                    "hl": nn.Conv2d(1, 1, 3, padding=1),
                    "hh": nn.Conv2d(1, 1, 3, padding=1),
                }
            )
            self.detail_mixers.append(detail_mixer)

    elif mixing_mode == "cross":
        # Cross-subband interaction
        # Each subband is processed per-channel with feature dimension 1 after flattening spatial dims
        # So attention operates on sequences of spatial positions with 1 feature per position
        self.cross_mixer = nn.MultiheadAttention(
            1,
            num_heads=1,  # Feature dim=1, so only 1 head possible
            batch_first=True,
        )

    elif mixing_mode == "attention":
        # Attention-based mixing across all subbands
        # Same as cross mode - feature dimension is 1 after spatial flattening
        self.subband_attention = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=1,  # Feature dimension is 1 after flattening spatial dimensions
                nhead=1,  # Only 1 head possible with d_model=1
                dim_feedforward=4,  # Minimal FFN since d_model=1
                batch_first=True,
            ),
            num_layers=2,
        )
    else:
        raise ValueError(f"Unknown mixing mode: {mixing_mode}")
Functions
forward
forward(x: Tensor) -> Tensor

Apply 2D wavelet-based mixing following the mathematical formulation.

Implements complete 2D wavelet mixing: spatial decomposition → subband mixing → reconstruction → residual connection. Each channel is processed independently.

Mathematical Implementation
  1. Channel Extraction: \(\mathbf{X}^{(c)} = \mathbf{X}[:, c, :, :]\) for each channel \(c\)
  2. 2D Wavelet Decomposition: \(\text{DWT2D}_J(\mathbf{X}^{(c)}) \rightarrow \text{subbands}\)
  3. Subband Mixing: Apply mode-specific transformations to wavelet subbands
  4. 2D Reconstruction: \(\text{IDWT2D}_J(\text{mixed subbands}) \rightarrow \tilde{\mathbf{X}}^{(c)}\)
  5. Channel Stacking: \(\hat{\mathbf{X}} = [\tilde{\mathbf{X}}^{(0)}, \ldots, \tilde{\mathbf{X}}^{(C-1)}]\)
  6. Residual Connection: \(\mathbf{Y} = \mathbf{X} + \hat{\mathbf{X}}\)

Parameters:

Name Type Description Default
x Tensor

Input tensor of shape \((B, C, H, W)\) where:

  • \(B\) is batch size
  • \(C\) is number of channels
  • \(H\) is height
  • \(W\) is width
required

Returns:

Type Description
Tensor

Mixed output tensor of identical shape \((B, C, H, W)\) with 2D wavelet-domain mixing applied and residual connection.

Notes
  • Spatial dimensions preserved through careful reconstruction handling
  • Different mixing strategies provide various inductive biases
  • Subband mode: Independent processing emphasizes local features
  • Cross mode: Attention enables global subband interactions
  • Attention mode: Full transformer encoder for complex dependencies
Source code in spectrans/layers/mixing/wavelet.py
def forward(self, x: Tensor) -> Tensor:
    r"""Apply 2D wavelet-based mixing following the mathematical formulation.

    Implements complete 2D wavelet mixing: spatial decomposition → subband mixing →
    reconstruction → residual connection. Each channel is processed independently.

    Mathematical Implementation
    ---------------------------
    1. **Channel Extraction**: $\mathbf{X}^{(c)} = \mathbf{X}[:, c, :, :]$ for each channel $c$
    2. **2D Wavelet Decomposition**: $\text{DWT2D}_J(\mathbf{X}^{(c)}) \rightarrow \text{subbands}$
    3. **Subband Mixing**: Apply mode-specific transformations to wavelet subbands
    4. **2D Reconstruction**: $\text{IDWT2D}_J(\text{mixed subbands}) \rightarrow \tilde{\mathbf{X}}^{(c)}$
    5. **Channel Stacking**: $\hat{\mathbf{X}} = [\tilde{\mathbf{X}}^{(0)}, \ldots, \tilde{\mathbf{X}}^{(C-1)}]$
    6. **Residual Connection**: $\mathbf{Y} = \mathbf{X} + \hat{\mathbf{X}}$

    Parameters
    ----------
    x : Tensor
        Input tensor of shape $(B, C, H, W)$ where:

        - $B$ is batch size
        - $C$ is number of channels
        - $H$ is height
        - $W$ is width

    Returns
    -------
    Tensor
        Mixed output tensor of identical shape $(B, C, H, W)$ with 2D wavelet-domain
        mixing applied and residual connection.

    Notes
    -----
    - Spatial dimensions preserved through careful reconstruction handling
    - Different mixing strategies provide various inductive biases
    - Subband mode: Independent processing emphasizes local features
    - Cross mode: Attention enables global subband interactions
    - Attention mode: Full transformer encoder for complex dependencies
    """
    _, channels, height, width = x.shape
    residual = x

    # Process each channel
    outputs = []
    for c in range(channels):
        x_channel = x[:, c : c + 1, :, :]

        # Decompose using 2D DWT
        ll, details = self.dwt.decompose(x_channel, dim=(-2, -1))

        # Apply mixing based on mode
        if self.mixing_mode == "subband":
            # Process LL subband
            ll_mixed = self.ll_mixer(ll)

            # Process detail subbands
            details_mixed = []
            for level, (lh, hl, hh) in enumerate(details):
                mixer = self.detail_mixers[level]
                lh_mixed = mixer["lh"](lh)  # type: ignore
                hl_mixed = mixer["hl"](hl)  # type: ignore
                hh_mixed = mixer["hh"](hh)  # type: ignore
                details_mixed.append((lh_mixed, hl_mixed, hh_mixed))

        elif self.mixing_mode == "cross":
            # Flatten spatial dimensions for attention
            ll_flat = ll.flatten(2).transpose(1, 2)
            details_flat = []
            for lh, hl, hh in details:
                details_flat.extend(
                    [
                        lh.flatten(2).transpose(1, 2),
                        hl.flatten(2).transpose(1, 2),
                        hh.flatten(2).transpose(1, 2),
                    ]
                )

            # Apply cross-attention
            all_subbands = torch.cat([ll_flat, *details_flat], dim=1)
            mixed, _ = self.cross_mixer(all_subbands, all_subbands, all_subbands)

            # Reshape back
            ll_size = ll.shape[2] * ll.shape[3]
            ll_mixed = mixed[:, :ll_size, :].transpose(1, 2).reshape_as(ll)

            details_mixed = []
            offset = ll_size
            for _level, (lh, hl, hh) in enumerate(details):
                lh_size = lh.shape[2] * lh.shape[3]
                hl_size = hl.shape[2] * hl.shape[3]
                hh_size = hh.shape[2] * hh.shape[3]

                lh_mixed = mixed[:, offset : offset + lh_size, :].transpose(1, 2).reshape_as(lh)
                offset += lh_size
                hl_mixed = mixed[:, offset : offset + hl_size, :].transpose(1, 2).reshape_as(hl)
                offset += hl_size
                hh_mixed = mixed[:, offset : offset + hh_size, :].transpose(1, 2).reshape_as(hh)
                offset += hh_size

                details_mixed.append((lh_mixed, hl_mixed, hh_mixed))

        else:  # attention mode
            # NOTE: self.subband_attention is built in __init__ but is not applied here;
            # the subbands pass through unchanged before reconstruction.
            ll_mixed = ll
            details_mixed = details

        # Reconstruct
        reconstructed = self.dwt.reconstruct((ll_mixed, details_mixed), dim=(-2, -1))

        # Ensure correct shape
        if reconstructed.shape[-2:] != (height, width):
            reconstructed = reconstructed[:, :, :height, :width]

        outputs.append(reconstructed)

    # Combine channels
    output = torch.cat(outputs, dim=1)

    # Residual connection
    output = output + residual

    return output
from_config classmethod
from_config(config: WaveletMixing2DConfig) -> WaveletMixing2D

Create WaveletMixing2D from configuration.

Parameters:

Name Type Description Default
config WaveletMixing2DConfig

Typed and validated configuration.

required

Returns:

Type Description
WaveletMixing2D

Configured instance.

Source code in spectrans/layers/mixing/wavelet.py
@classmethod
def from_config(cls, config: "WaveletMixing2DConfig") -> "WaveletMixing2D":
    """Create WaveletMixing2D from configuration.

    Parameters
    ----------
    config : WaveletMixing2DConfig
        Typed and validated configuration.

    Returns
    -------
    WaveletMixing2D
        Configured instance.
    """
    return cls(
        channels=config.channels,
        wavelet=config.wavelet,
        levels=config.levels,
        mixing_mode=config.mixing_mode,
    )