
Layer Implementations

spectrans.layers

Layer implementations for spectral transformers.

Provides spectral transformer layers that replace traditional attention mechanisms with spectral operations. The layers fall into three categories (mixing layers, attention layers, and neural operators) and are designed to be compatible with standard transformer architectures.

Modules:

Name Description
attention

Spectral attention mechanisms with linear or log-linear complexity.

mixing

Token mixing layers using spectral transforms.

operators

Fourier neural operators for function space learning.

Classes:

Name Description
AdaptiveGlobalFilter

Enhanced global filter with adaptive initialization.

AFNOMixing

Adaptive Fourier Neural Operator with mode truncation.

DCTAttention

Specialized LST attention using discrete cosine transform.

FilterMixingLayer

Base class for learnable frequency domain filters.

FNOBlock

FNO block with spectral convolution and feedforward.

FourierMixing

2D FFT mixing for both sequence and feature dimensions (FNet).

FourierMixing1D

1D FFT mixing along sequence dimension only.

FourierNeuralOperator

Base FNO layer for learning operators in function spaces.

GlobalFilterMixing

Learnable complex filters in frequency domain (GFNet).

GlobalFilterMixing2D

2D variant with filtering in both dimensions.

HadamardAttention

Fast attention using Hadamard transform operations.

KernelAttention

General kernel-based attention with various kernel options.

LSTAttention

Linear Spectral Transform attention with configurable transforms.

MixedSpectralAttention

Multi-transform attention combining multiple spectral methods.

MixingLayer

Base class for spectral mixing operations.

PerformerAttention

Performer-style attention with FAVOR+ algorithm.

RealFourierMixing

Memory-efficient real FFT variant for real-valued inputs.

SeparableFourierMixing

Configurable sequence and/or feature mixing.

SpectralAttention

Multi-head spectral attention using random Fourier features.

SpectralConv1d

1D spectral convolution operator for sequence data.

SpectralConv2d

2D spectral convolution operator for image-like data.

UnitaryMixingLayer

Base class for energy-preserving mixing transforms.

WaveletMixing

1D wavelet mixing using discrete wavelet transform.

WaveletMixing2D

2D wavelet mixing for spatial data processing.

Examples:

Basic Fourier mixing layer (FNet-style):

>>> import torch
>>> from spectrans.layers import FourierMixing
>>>
>>> # Create Fourier mixing layer
>>> mixer = FourierMixing(hidden_dim=768)
>>> x = torch.randn(32, 512, 768)  # (batch, sequence, hidden)
>>> output = mixer(x)
>>> assert output.shape == x.shape

Global filter mixing with learnable parameters:

>>> from spectrans.layers import GlobalFilterMixing
>>>
>>> # Create global filter with learnable complex weights
>>> filter_layer = GlobalFilterMixing(
...     hidden_dim=512,
...     sequence_length=1024,
...     activation='sigmoid'
... )
>>> x = torch.randn(16, 1024, 512)
>>> output = filter_layer(x)

Spectral attention with random Fourier features:

>>> from spectrans.layers import SpectralAttention
>>>
>>> # Create spectral attention layer
>>> attention = SpectralAttention(
...     hidden_dim=768,
...     num_heads=12,
...     num_features=256
... )
>>> x = torch.randn(8, 256, 768)
>>> output = attention(x)

Notes

Layer Categories and Complexity:

Mixing layers run in \(O(n \log n)\) (FFT-based) or \(O(n)\) (wavelet-based) time. Parameter-free variants such as FNet-style Fourier mixing apply fixed FFTs, while global filters and AFNO learn weights in the frequency domain. Wavelet variants use discrete wavelet transforms for multiresolution, hierarchical processing.
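The two FFT-based mixing styles can be sketched in a few lines of NumPy. This is a minimal illustration, assuming FNet mixing takes the real part of a 2D DFT over (sequence, hidden) and GFNet mixing multiplies the sequence-axis real FFT by a complex filter; `fnet_mix` and `global_filter_mix` are hypothetical helpers, not the spectrans API.

```python
import numpy as np

def fnet_mix(x: np.ndarray) -> np.ndarray:
    """FNet-style mixing: real part of a 2D DFT over the (sequence, hidden) axes."""
    return np.fft.fft2(x, axes=(-2, -1)).real

def global_filter_mix(x: np.ndarray, filt: np.ndarray) -> np.ndarray:
    """GFNet-style mixing: real FFT along the sequence, element-wise complex filter, inverse FFT."""
    n = x.shape[-2]
    x_freq = np.fft.rfft(x, axis=-2)              # (batch, n // 2 + 1, hidden), complex
    return np.fft.irfft(x_freq * filt, n=n, axis=-2)

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 16, 8))               # (batch, sequence, hidden)
# In a real layer this filter would be a trainable parameter; here it is random.
filt = rng.standard_normal((9, 8)) + 1j * rng.standard_normal((9, 8))

y_fnet = fnet_mix(x)                              # no parameters involved
y_gf = global_filter_mix(x, filt)                 # 9 * 8 complex weights
```

Both outputs keep the input shape, and an all-ones filter makes `global_filter_mix` the identity, which is a convenient sanity check for the learnable variant.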

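Attention layers avoid the \(O(n^2)\) cost of softmax attention. Kernel approximations with Random Fourier Features and orthogonal features (SpectralAttention, PerformerAttention) scale linearly, \(O(n)\), in sequence length; transform-based methods using the DCT, DST, and Hadamard transforms (LSTAttention and its variants) run in \(O(n \log n)\); hybrid approaches combine multiple transforms with learnable per-transform scales. The polynomial kernel in KernelAttention is an exception: it builds the full \(n \times n\) kernel matrix.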
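The linear-complexity kernel route rests on reassociating a matrix product: with a non-negative feature map \(\phi\), \((\phi(Q)\phi(K)^\top)V = \phi(Q)(\phi(K)^\top V)\), and the right-hand side never forms an \(n \times n\) matrix. The sketch below uses an ELU+1 feature map purely as an illustrative assumption (the spectrans layers use random-feature maps instead) to show that both orderings give the same normalized output.

```python
import numpy as np

def elu_plus_one(x):
    # Simple positive feature map chosen for illustration only.
    return np.where(x > 0, x + 1.0, np.exp(x))

rng = np.random.default_rng(0)
n, d = 128, 16
Q, K, V = (rng.standard_normal((n, d)) for _ in range(3))
phi_q, phi_k = elu_plus_one(Q), elu_plus_one(K)

# Quadratic ordering: builds the explicit n x n similarity matrix, O(n^2 d).
scores = phi_q @ phi_k.T
quadratic = (scores @ V) / scores.sum(axis=1, keepdims=True)

# Linear ordering: summarise keys and values first, O(n d^2), no n x n matrix.
kv = phi_k.T @ V                  # (d, d)
k_sum = phi_k.sum(axis=0)         # (d,)
linear = (phi_q @ kv) / (phi_q @ k_sum)[:, None]
```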

Neural operators cost \(O(k \cdot d^2 + n \log n)\), where \(k\) is the number of retained Fourier modes, \(d\) the channel dimension, and \(n\) the number of grid points. They learn mappings between function spaces; because the weights are parameterized on Fourier modes rather than on grid points, the same layer can be applied at different discretizations (resolution invariance).
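A 1D spectral convolution with mode truncation can be sketched in NumPy to make the cost terms concrete: the FFT contributes \(O(n \log n)\) and the per-mode channel mixing \(O(k \cdot d^2)\). `spectral_conv1d` is a hypothetical helper; the argument order and weight layout of spectrans' `SpectralConv1d` may differ.

```python
import numpy as np

def spectral_conv1d(x, weights, modes):
    """FNO-style spectral convolution along the sequence axis.

    x:       (batch, n, d_in) real input
    weights: (modes, d_in, d_out) complex weights, one d_in x d_out matrix per kept mode
    """
    batch, n, _ = x.shape
    d_out = weights.shape[-1]
    x_ft = np.fft.rfft(x, axis=1)                              # (batch, n // 2 + 1, d_in)
    out_ft = np.zeros((batch, n // 2 + 1, d_out), dtype=complex)
    # Mix channels only for the lowest `modes` frequencies; higher ones are truncated to zero.
    out_ft[:, :modes] = np.einsum("bki,kio->bko", x_ft[:, :modes], weights)
    return np.fft.irfft(out_ft, n=n, axis=1)                   # (batch, n, d_out)

rng = np.random.default_rng(0)
modes, d_in, d_out = 4, 3, 5
w = rng.standard_normal((modes, d_in, d_out)) + 1j * rng.standard_normal((modes, d_in, d_out))

# Sample the same band-limited function on a coarse and a fine grid.
t_c = np.arange(32) / 32
t_f = np.arange(128) / 128
x_c = np.repeat(np.sin(2 * np.pi * t_c)[None, :, None], d_in, axis=2)
x_f = np.repeat(np.sin(2 * np.pi * t_f)[None, :, None], d_in, axis=2)

y_c = spectral_conv1d(x_c, w, modes)
y_f = spectral_conv1d(x_f, w, modes)
```

Because the weights live on frequencies, the fine-grid output sampled at every fourth point matches the coarse-grid output for this band-limited input, which is the resolution invariance described above.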

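The FFT-based layers rely on the convolution theorem for global mixing:

\[ \mathcal{F}[f \ast g] = \mathcal{F}[f] \odot \mathcal{F}[g] \]

where \(\ast\) denotes convolution (circular convolution for the DFT) and \(\odot\) the element-wise product. A global convolution over all \(n\) tokens therefore costs \(O(n \log n)\) instead of the \(O(n^2)\) of full attention.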
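A quick numerical check of the theorem in the discrete setting, where the DFT turns circular convolution into an element-wise product. `circular_conv` is an illustrative direct implementation used only for comparison.

```python
import numpy as np

def circular_conv(f, g):
    """Direct O(n^2) circular convolution: (f * g)[m] = sum_j f[j] g[(m - j) mod n]."""
    n = len(f)
    return np.array([sum(f[j] * g[(m - j) % n] for j in range(n)) for m in range(n)])

rng = np.random.default_rng(0)
f, g = rng.standard_normal(16), rng.standard_normal(16)

direct = circular_conv(f, g)                                   # O(n^2)
via_fft = np.fft.ifft(np.fft.fft(f) * np.fft.fft(g)).real      # O(n log n)
```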

References

James Lee-Thorp, Joshua Ainslie, Ilya Eckstein, and Santiago Ontanon. 2022. FNet: Mixing tokens with Fourier transforms. In Proceedings of the 2022 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies (NAACL-HLT), pages 4296-4313, Seattle.

Yongming Rao, Wenliang Zhao, Zheng Zhu, Jiwen Lu, and Jie Zhou. 2021. Global filter networks for image classification. In Advances in Neural Information Processing Systems 34 (NeurIPS 2021), pages 980-993.

John Guibas, Morteza Mardani, Zongyi Li, Andrew Tao, Anima Anandkumar, and Bryan Catanzaro. 2022. Adaptive Fourier neural operators: Efficient token mixers for transformers. In Proceedings of the International Conference on Learning Representations (ICLR).

Zongyi Li, Nikola Kovachki, Kamyar Azizzadenesheli, Burigede Liu, Kaushik Bhattacharya, Andrew Stuart, and Anima Anandkumar. 2021. Fourier neural operator for parametric partial differential equations. In Proceedings of the International Conference on Learning Representations (ICLR).

Krzysztof Choromanski, Valerii Likhosherstov, David Dohan, Xingyou Song, Andreea Gane, Tamas Sarlos, Peter Hawkins, Jared Davis, Afroz Mohiuddin, Lukasz Kaiser, David Belanger, Lucy Colwell, and Adrian Weller. 2021. Rethinking attention with performers. In Proceedings of the International Conference on Learning Representations (ICLR).

See Also

spectrans.transforms : Underlying spectral transform implementations.
spectrans.models : Model implementations using these layers.
spectrans.blocks : Transformer blocks that compose these layers.

Classes

DCTAttention

DCTAttention(hidden_dim: int, num_heads: int = 8, dct_type: int = 2, learnable_scale: bool = True, dropout: float = 0.0)

Bases: LSTAttention

Attention using Discrete Cosine Transform.

Specialized LST attention that applies the DCT in every head. The DCT is real-valued and concentrates most of the energy of smooth signals in a few low-frequency coefficients.

Parameters:

Name Type Description Default
hidden_dim int

Hidden dimension.

required
num_heads int

Number of attention heads.

8
dct_type int

DCT type. Only type 2 is currently implemented; other values are stored in dct_type, but a type-2 DCT is still used.

2
learnable_scale bool

Whether to use learnable scaling.

True
dropout float

Dropout probability.

0.0
Source code in spectrans/layers/attention/lst.py
def __init__(
    self,
    hidden_dim: int,
    num_heads: int = 8,
    dct_type: int = 2,
    learnable_scale: bool = True,
    dropout: float = 0.0,
):
    super().__init__(
        hidden_dim=hidden_dim,
        num_heads=num_heads,
        transform_type="dct",
        learnable_scale=learnable_scale,
        normalize=True,
        dropout=dropout,
    )

    self.dct_type = dct_type

    # Override transform with specific DCT type
    # Note: Current DCT implementation only supports type 2
    # Future versions may support other types
    if dct_type != 2:
        # For now, still use type 2 DCT
        pass
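As a reference for what the DCT contributes, here is a NumPy sketch of the orthonormal DCT-II, the only type this layer currently applies. It assumes the `normalized=True` DCT used by the LST layers corresponds to this orthonormal scaling; `dct2_matrix` is a hypothetical helper, and a dense \(O(n^2)\) matrix is used for clarity instead of a fast algorithm.

```python
import numpy as np

def dct2_matrix(n):
    """Orthonormal DCT-II matrix C with C[k, j] = a_k * cos(pi * (2j + 1) * k / (2n))."""
    k = np.arange(n)[:, None]
    j = np.arange(n)[None, :]
    c = np.cos(np.pi * (2 * j + 1) * k / (2 * n))
    a = np.full((n, 1), np.sqrt(2.0 / n))
    a[0, 0] = np.sqrt(1.0 / n)                      # DC row gets the smaller normalization
    return a * c

n = 32
C = dct2_matrix(n)
signal = np.linspace(0.0, 1.0, n) ** 2             # smooth, real-valued input
coeffs = C @ signal                                # forward DCT-II
recovered = C.T @ coeffs                           # inverse is the transpose (DCT-III)
energy_in_first_4 = np.sum(coeffs[:4] ** 2) / np.sum(coeffs ** 2)
```

The matrix is orthogonal, so the transform preserves energy, and for this smooth ramp nearly all of the energy lands in the first few coefficients (the energy compaction mentioned above).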

HadamardAttention

HadamardAttention(hidden_dim: int, num_heads: int = 8, scale_by_sqrt: bool = True, learnable_scale: bool = True, dropout: float = 0.0)

Bases: LSTAttention

Attention using fast Hadamard transform.

Uses the Walsh-Hadamard transform, whose matrix entries are all ±1, for \(O(n \log n)\) attention computation.

Parameters:

Name Type Description Default
hidden_dim int

Hidden dimension.

required
num_heads int

Number of attention heads.

8
scale_by_sqrt bool

Whether to divide the initial diagonal scale by sqrt(head_dim).

True
learnable_scale bool

Whether to use learnable diagonal scaling.

True
dropout float

Dropout probability.

0.0
Source code in spectrans/layers/attention/lst.py
def __init__(
    self,
    hidden_dim: int,
    num_heads: int = 8,
    scale_by_sqrt: bool = True,
    learnable_scale: bool = True,
    dropout: float = 0.0,
):
    super().__init__(
        hidden_dim=hidden_dim,
        num_heads=num_heads,
        transform_type="hadamard",
        learnable_scale=learnable_scale,
        normalize=True,
        dropout=dropout,
    )

    self.scale_by_sqrt = scale_by_sqrt

    # Additional scaling for Hadamard
    if scale_by_sqrt:
        # Divide the initial scale by sqrt(head_dim)
        with torch.no_grad():
            self.scale.data = self.scale.data / math.sqrt(self.head_dim)
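The \(O(n \log n)\) cost comes from the butterfly structure of the fast Walsh-Hadamard transform, which uses only additions and subtractions. Below is a pure-Python sketch of the Sylvester-ordered transform, which requires a power-of-two length; `fwht` and `fwht_normalized` are illustrative helpers, and how spectrans handles other lengths is not covered here.

```python
import math

def fwht(x):
    """Fast Walsh-Hadamard transform in O(n log n); len(x) must be a power of two."""
    n = len(x)
    if n == 0 or (n & (n - 1)) != 0:
        raise ValueError("length must be a power of two")
    y = list(x)
    h = 1
    while h < n:
        for start in range(0, n, 2 * h):
            for i in range(start, start + h):
                a, b = y[i], y[i + h]
                y[i], y[i + h] = a + b, a - b   # butterfly: one addition, one subtraction
        h *= 2
    return y

def fwht_normalized(x):
    """Orthonormal variant: scaling by 1/sqrt(n) makes the transform its own inverse."""
    scale = 1.0 / math.sqrt(len(x))
    return [v * scale for v in fwht(x)]

x = [1.0, 2.0, 3.0, 4.0]
coeffs = fwht(x)                                 # [10.0, -2.0, -4.0, 0.0]
roundtrip = fwht_normalized(fwht_normalized(x))  # back to x
```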

KernelAttention

KernelAttention(hidden_dim: int, num_heads: int = 8, kernel_type: Literal['gaussian', 'polynomial', 'spectral'] = 'gaussian', rank: int | None = None, num_features: int | None = None, dropout: float = 0.0)

Bases: AttentionLayer

General kernel-based attention with various kernel options.

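Supports Gaussian, polynomial, and learnable spectral kernels. The Gaussian and spectral kernels use a feature-based linear attention path; the polynomial kernel computes the full kernel matrix with softmax normalization and is intended for short sequences.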

Parameters:

Name Type Description Default
hidden_dim int

Hidden dimension.

required
num_heads int

Number of heads.

8
kernel_type Literal['gaussian', 'polynomial', 'spectral']

Type of kernel to use.

"gaussian"
rank int | None

Rank for low-rank approximations. If None, uses min(64, head_dim).

None
num_features int | None

Number of random features for the Gaussian kernel. If None, uses head_dim.

None
dropout float

Dropout probability.

0.0

Attributes:

Name Type Description
kernel_type str

Type of kernel being used.

rank int

Rank for approximations; resolved to min(64, head_dim) when not given.

Methods:

Name Description
forward

Forward pass of kernel attention.

Source code in spectrans/layers/attention/spectral.py
def __init__(
    self,
    hidden_dim: int,
    num_heads: int = 8,
    kernel_type: Literal["gaussian", "polynomial", "spectral"] = "gaussian",
    rank: int | None = None,
    num_features: int | None = None,
    dropout: float = 0.0,
):
    super().__init__(hidden_dim, num_heads, dropout)

    self.head_dim = hidden_dim // num_heads
    self.kernel_type = kernel_type
    self.rank = rank or min(64, self.head_dim)

    # Projections
    self.q_proj = nn.Linear(hidden_dim, hidden_dim)
    self.k_proj = nn.Linear(hidden_dim, hidden_dim)
    self.v_proj = nn.Linear(hidden_dim, hidden_dim)
    self.out_proj = nn.Linear(hidden_dim, hidden_dim)

    # Initialize kernel - using union type to handle different kernel types
    if kernel_type == "gaussian":
        from ...kernels import GaussianRFFKernel

        self.kernel = GaussianRFFKernel(
            input_dim=self.head_dim,
            num_features=num_features or self.head_dim,
            sigma=1.0 / math.sqrt(self.head_dim),
        )
        self.use_features = True

    elif kernel_type == "polynomial":
        from ...kernels import PolynomialSpectralKernel

        self.kernel = PolynomialSpectralKernel(
            rank=self.rank,
            degree=2,
        )
        self.use_features = False

    else:  # spectral
        from ...kernels import LearnableSpectralKernel

        self.kernel = LearnableSpectralKernel(
            input_dim=self.head_dim,
            rank=self.rank,
            trainable_eigenvectors=True,
        )
        self.use_features = True
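The Gaussian branch relies on random Fourier features. The sketch below shows the textbook Rahimi-Recht construction, \(z(x) = \sqrt{2/D}\cos(Wx + b)\) with \(W \sim \mathcal{N}(0, \sigma^{-2} I)\); how `GaussianRFFKernel` parameterizes `sigma` or samples frequencies internally is an assumption here, not something this page documents. `gaussian_rff` is a hypothetical helper.

```python
import numpy as np

def gaussian_rff(x, w, b):
    """Random Fourier features z(x) = sqrt(2 / D) * cos(W x + b)."""
    num_features = w.shape[0]
    return np.sqrt(2.0 / num_features) * np.cos(x @ w.T + b)

rng = np.random.default_rng(0)
dim, num_features, sigma = 8, 20000, 1.0
# Frequencies drawn from the kernel's spectral density N(0, sigma^-2 I), phases from U[0, 2*pi).
w = rng.standard_normal((num_features, dim)) / sigma
b = rng.uniform(0.0, 2 * np.pi, num_features)

x, y = rng.standard_normal(dim) * 0.3, rng.standard_normal(dim) * 0.3
approx = gaussian_rff(x, w, b) @ gaussian_rff(y, w, b)        # feature inner product
exact = np.exp(-np.sum((x - y) ** 2) / (2 * sigma**2))        # Gaussian kernel value
```

The estimate's error shrinks roughly like \(1/\sqrt{D}\), which is why `num_features` trades accuracy for cost.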
Functions
forward
forward(x: Tensor, mask: Tensor | None = None, return_attention: bool = False) -> Tensor | tuple[Tensor, ...]

Forward pass of kernel attention.

Parameters:

Name Type Description Default
x Tensor

Input of shape (batch_size, seq_len, hidden_dim).

required
mask Tensor | None

Boolean mask of shape (batch_size, seq_len); True marks valid (non-padding) positions.

None
return_attention bool

Whether to return attention weights.

False

Returns:

Type Description
Tensor or tuple[Tensor, Tensor]

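Output of shape (batch_size, seq_len, hidden_dim) and, if return_attention=True, the attention weights. The weights are None on the feature-based path (Gaussian and spectral kernels).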

Source code in spectrans/layers/attention/spectral.py
def forward(
    self,
    x: Tensor,
    mask: Tensor | None = None,
    return_attention: bool = False,
) -> Tensor | tuple[Tensor, ...]:
    """Forward pass of kernel attention.

    Parameters
    ----------
    x : Tensor
        Input of shape (batch_size, seq_len, hidden_dim).
    mask : Tensor | None, default=None
        Attention mask.
    return_attention : bool, default=False
        Whether to return attention weights.

    Returns
    -------
    Tensor or tuple[Tensor, Tensor]
        Output and optionally attention weights.
    """
    batch_size, seq_len, _ = x.shape

    # Projections and reshape
    Q = self.q_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim)
    K = self.k_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim)
    V = self.v_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim)

    Q = Q.transpose(1, 2)
    K = K.transpose(1, 2)
    V = V.transpose(1, 2)

    if self.use_features:
        # Use feature-based approximation - kernel should be a callable (RandomFeatureMap)
        if hasattr(self.kernel, "extract_features"):
            Q_feat = self.kernel.extract_features(Q)  # type: ignore[operator]
            K_feat = self.kernel.extract_features(K)  # type: ignore[operator]
        else:
            # Call the kernel as a function
            Q_feat = self.kernel(Q)  # type: ignore[operator]
            K_feat = self.kernel(K)  # type: ignore[operator]

        if mask is not None:
            mask_exp = mask.unsqueeze(1).unsqueeze(-1)
            K_feat = K_feat.masked_fill(~mask_exp, 0)
            V = V.masked_fill(~mask_exp, 0)

        # Linear attention computation
        KV = torch.matmul(K_feat.transpose(-2, -1), V)
        out: Tensor = torch.matmul(Q_feat, KV)

        # Normalize
        K_sum = K_feat.sum(dim=-2, keepdim=True).transpose(-2, -1)
        normalizer = torch.matmul(Q_feat, K_sum) + 1e-6
        out = out / normalizer

        attention_weights: Tensor | None = None

    else:
        # Direct kernel computation (for small sequences)
        # Flatten heads and batch for kernel computation
        Q_flat = Q.reshape(-1, seq_len, self.head_dim)
        K_flat = K.reshape(-1, seq_len, self.head_dim)

        # Compute kernel matrix
        attention_weights = self.kernel.compute(Q_flat, K_flat)  # type: ignore[operator]
        attention_weights = attention_weights.view(batch_size, self.num_heads, seq_len, seq_len)

        if mask is not None:
            mask_exp = mask.unsqueeze(1).unsqueeze(2)
            attention_weights = attention_weights.masked_fill(~mask_exp, -1e9)

        # Normalize
        attention_weights = F.softmax(attention_weights, dim=-1)
        attention_weights = self.dropout(attention_weights)

        # Apply to values
        out = torch.matmul(attention_weights, V)

    # Reshape output
    out = out.transpose(1, 2).contiguous()
    out = out.view(batch_size, seq_len, self.hidden_dim)
    out = self.out_proj(out)
    out = self.dropout(out)

    if return_attention:
        return out, attention_weights  # type: ignore[return-value]
    return out
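The feature-based branch above zeroes masked keys and values before forming the \(\phi(K)^\top V\) summary. A single-head NumPy sketch of that logic (batch and head axes dropped; `masked_linear_attention` is a hypothetical helper) shows that padded positions cannot leak into the output. It uses positive features to keep the normalizer positive; with cosine-based random features the normalizer is not guaranteed to be positive.

```python
import numpy as np

def masked_linear_attention(q_feat, k_feat, v, mask):
    """Zero masked keys/values, then compute normalized linear attention.

    q_feat, k_feat: (n, m) non-negative features; v: (n, d); mask: (n,) bool, True = keep.
    """
    keep = mask[:, None]
    k_feat = np.where(keep, k_feat, 0.0)
    v = np.where(keep, v, 0.0)
    kv = k_feat.T @ v                                   # (m, d), costs O(n m d)
    normalizer = q_feat @ k_feat.sum(axis=0) + 1e-6     # (n,)
    return (q_feat @ kv) / normalizer[:, None]

rng = np.random.default_rng(0)
n, m, d = 10, 6, 4
q_feat = np.exp(rng.standard_normal((n, m)))            # positive features
k_feat = np.exp(rng.standard_normal((n, m)))
v = rng.standard_normal((n, d))
mask = np.array([True] * 7 + [False] * 3)               # last three positions are padding

out = masked_linear_attention(q_feat, k_feat, v, mask)
v_changed = v.copy()
v_changed[~mask] = 100.0                                # corrupt only the padded values
out_changed = masked_linear_attention(q_feat, k_feat, v_changed, mask)
```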

LSTAttention

LSTAttention(hidden_dim: int, num_heads: int = 8, transform_type: Literal['dct', 'dst', 'hadamard', 'mixed'] = 'dct', learnable_scale: bool = True, normalize: bool = True, dropout: float = 0.0, use_bias: bool = True)

Bases: AttentionLayer

Linear Spectral Transform attention mechanism.

Implements attention using orthogonal transforms (DCT, DST, Hadamard) with learnable diagonal scaling.

Parameters:

Name Type Description Default
hidden_dim int

Hidden dimension of the model.

required
num_heads int

Number of attention heads.

8
transform_type Literal['dct', 'dst', 'hadamard', 'mixed']

Type of transform to use. "mixed" cycles through DCT, DST, and Hadamard across heads.

"dct"
learnable_scale bool

Whether to use learnable diagonal scaling matrix.

True
normalize bool

Whether to divide the transform-domain output by the summed magnitude of the attention coefficients over the feature dimension.

True
dropout float

Dropout probability.

0.0
use_bias bool

Whether to use bias in projections.

True

Attributes:

Name Type Description
head_dim int

Dimension per attention head.

transform_type str

Type of transform being used.

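transforms ModuleList

One transform per head; unless transform_type is "mixed", every entry is the same shared transform.

scale Parameter | Tensor

Diagonal scaling of shape (num_heads, 1, head_dim): a learnable Parameter if learnable_scale=True, otherwise a fixed buffer of ones.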

Methods:

Name Description
forward

Forward pass of LST attention.

Source code in spectrans/layers/attention/lst.py
def __init__(
    self,
    hidden_dim: int,
    num_heads: int = 8,
    transform_type: Literal["dct", "dst", "hadamard", "mixed"] = "dct",
    learnable_scale: bool = True,
    normalize: bool = True,
    dropout: float = 0.0,
    use_bias: bool = True,
):
    super().__init__(hidden_dim, num_heads, dropout)

    self.head_dim = hidden_dim // num_heads
    assert self.head_dim * num_heads == hidden_dim, (
        f"hidden_dim {hidden_dim} must be divisible by num_heads {num_heads}"
    )

    self.transform_type = transform_type
    self.normalize = normalize

    # Projections
    self.q_proj = nn.Linear(hidden_dim, hidden_dim, bias=use_bias)
    self.k_proj = nn.Linear(hidden_dim, hidden_dim, bias=use_bias)
    self.v_proj = nn.Linear(hidden_dim, hidden_dim, bias=use_bias)
    self.out_proj = nn.Linear(hidden_dim, hidden_dim, bias=use_bias)

    # Initialize transforms
    self.transforms: nn.ModuleList = nn.ModuleList()  # Contains SpectralTransform objects
    if transform_type == "mixed":
        # Use different transforms for different heads
        transform_types = ["dct", "dst", "hadamard"]
        for i in range(num_heads):
            t_type = transform_types[i % len(transform_types)]
            self.transforms.append(self._create_transform(t_type))
    else:
        # Use same transform for all heads
        transform = self._create_transform(transform_type)
        for _ in range(num_heads):
            self.transforms.append(transform)

    # Learnable diagonal scaling
    if learnable_scale:
        # One scale per head and feature channel, broadcast over sequence positions
        self.scale = nn.Parameter(torch.ones(num_heads, 1, self.head_dim))
    else:
        self.register_buffer("scale", torch.ones(num_heads, 1, self.head_dim))
Functions
forward
forward(x: Tensor, mask: Tensor | None = None, return_attention: bool = False) -> Tensor | tuple[Tensor, ...]

Forward pass of LST attention.

Parameters:

Name Type Description Default
x Tensor

Input tensor of shape (batch_size, seq_len, hidden_dim).

required
mask Tensor | None

Attention mask of shape (batch_size, seq_len). It is transformed along the sequence axis and multiplied into the transformed keys and values.

None
return_attention bool

Whether to return attention weights (not supported).

False

Returns:

Type Description
Tensor or tuple[Tensor, Tensor]

Output tensor of shape (batch_size, seq_len, hidden_dim). If return_attention=True, returns (output, None).

Source code in spectrans/layers/attention/lst.py
def forward(
    self,
    x: Tensor,
    mask: Tensor | None = None,
    return_attention: bool = False,
) -> Tensor | tuple[Tensor, ...]:
    """Forward pass of LST attention.

    Parameters
    ----------
    x : Tensor
        Input tensor of shape (batch_size, seq_len, hidden_dim).
    mask : Tensor | None, default=None
        Attention mask of shape (batch_size, seq_len).
    return_attention : bool, default=False
        Whether to return attention weights (not supported).

    Returns
    -------
    Tensor or tuple[Tensor, Tensor]
        Output tensor of shape (batch_size, seq_len, hidden_dim).
        If return_attention=True, returns (output, None).
    """
    batch_size, seq_len, _ = x.shape

    # Linear projections
    Q = self.q_proj(x)
    K = self.k_proj(x)
    V = self.v_proj(x)

    # Reshape for multi-head attention
    Q = Q.view(batch_size, seq_len, self.num_heads, self.head_dim)
    K = K.view(batch_size, seq_len, self.num_heads, self.head_dim)
    V = V.view(batch_size, seq_len, self.num_heads, self.head_dim)

    # Transpose to (batch, heads, seq_len, head_dim)
    Q = Q.transpose(1, 2)
    K = K.transpose(1, 2)
    V = V.transpose(1, 2)

    # Apply transforms per head
    outputs = []
    for head_idx in range(self.num_heads):
        q_head = Q[:, head_idx]  # (batch, seq_len, head_dim)
        k_head = K[:, head_idx]
        v_head = V[:, head_idx]

        # Get transform for this head
        transform: SpectralTransform
        if self.transform_type == "mixed":
            transform = self.transforms[head_idx]  # type: ignore[assignment]
        else:
            transform = self.transforms[0]  # type: ignore[assignment]

        # Apply transform along sequence dimension
        q_transformed = transform.transform(q_head, dim=-2)
        k_transformed = transform.transform(k_head, dim=-2)
        v_transformed = transform.transform(v_head, dim=-2)

        # Apply mask in transform domain if provided
        if mask is not None:
            # Transform mask to frequency domain
            mask_float = mask.float().unsqueeze(-1)  # (batch, seq_len, 1)
            mask_transformed = transform.transform(mask_float, dim=-2)
            k_transformed = k_transformed * mask_transformed
            v_transformed = v_transformed * mask_transformed

        # Element-wise multiplication in transform domain
        # This replaces the QK^T computation
        attention_transformed = q_transformed * k_transformed * self.scale[head_idx]

        # Apply to values
        output_transformed = attention_transformed * v_transformed

        # Normalize if requested
        if self.normalize:
            # Compute normalization factor
            norm_factor = torch.abs(attention_transformed).sum(dim=-1, keepdim=True) + 1e-6
            output_transformed = output_transformed / norm_factor

        # Inverse transform
        output_head = transform.inverse_transform(output_transformed, dim=-2)

        # Real part for numerical stability
        if torch.is_complex(output_head):
            output_head = output_head.real

        outputs.append(output_head.unsqueeze(1))

    # Concatenate heads
    out = torch.cat(outputs, dim=1)  # (batch, heads, seq_len, head_dim)

    # Reshape
    out = out.transpose(1, 2).contiguous()
    out = out.view(batch_size, seq_len, self.hidden_dim)

    # Output projection and dropout
    out = self.out_proj(out)
    out = self.dropout(out)

    output: Tensor = out
    if return_attention:
        # Attention weights not available in LST
        return output, None  # type: ignore[return-value]
    return output
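The per-head loop can be condensed into a NumPy sketch for a single head without masking. It assumes the `dct` transform is an orthonormal DCT-II along the sequence axis; `dct_ortho` and `lst_head` are illustrative helpers, and the dense matrix costs \(O(n^2)\) per channel, whereas a fast transform would give \(O(n \log n)\).

```python
import numpy as np

def dct_ortho(n):
    """Dense orthonormal DCT-II matrix (illustrative; fast algorithms avoid building it)."""
    k = np.arange(n)[:, None]
    j = np.arange(n)[None, :]
    c = np.sqrt(2.0 / n) * np.cos(np.pi * (2 * j + 1) * k / (2 * n))
    c[0] /= np.sqrt(2.0)
    return c

def lst_head(q, k, v, scale, normalize=True):
    """One LST head: transform along the sequence, multiply element-wise, invert.

    q, k, v: (n, head_dim); scale: (head_dim,) diagonal scaling.
    """
    c = dct_ortho(q.shape[0])
    q_t, k_t, v_t = c @ q, c @ k, c @ v        # transform along the sequence axis
    attn_t = q_t * k_t * scale                 # element-wise product replaces Q K^T
    out_t = attn_t * v_t
    if normalize:
        out_t = out_t / (np.abs(attn_t).sum(axis=-1, keepdims=True) + 1e-6)
    return c.T @ out_t                         # inverse DCT-II via the transpose

rng = np.random.default_rng(0)
n, head_dim = 16, 8
q, k, v = (rng.standard_normal((n, head_dim)) for _ in range(3))
out = lst_head(q, k, v, scale=np.ones(head_dim))
```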

MixedSpectralAttention

MixedSpectralAttention(hidden_dim: int, num_heads: int = 9, use_fft: bool = True, use_dct: bool = True, use_hadamard: bool = True, dropout: float = 0.0)

Bases: AttentionLayer

Mixed spectral attention using multiple transform types.

Combines different spectral transforms across heads for diverse frequency representations.

Parameters:

Name Type Description Default
hidden_dim int

Hidden dimension.

required
num_heads int

Number of attention heads. Transforms are assigned to heads in round-robin order, so a multiple of the number of enabled transforms (3 by default) gives an even split.

9
use_fft bool

Whether to include FFT heads.

True
use_dct bool

Whether to include DCT heads.

True
use_hadamard bool

Whether to include Hadamard heads.

True
dropout float

Dropout probability.

0.0

Methods:

Name Description
forward

Forward pass of mixed spectral attention.

Source code in spectrans/layers/attention/lst.py
def __init__(
    self,
    hidden_dim: int,
    num_heads: int = 9,  # Divisible by 3
    use_fft: bool = True,
    use_dct: bool = True,
    use_hadamard: bool = True,
    dropout: float = 0.0,
):
    super().__init__(hidden_dim, num_heads, dropout)

    self.head_dim = hidden_dim // num_heads

    # Count enabled transform types
    enabled_transforms = []
    if use_fft:
        enabled_transforms.append("fft")
    if use_dct:
        enabled_transforms.append("dct")
    if use_hadamard:
        enabled_transforms.append("hadamard")

    if not enabled_transforms:
        raise ValueError("At least one transform type must be enabled")

    self.enabled_transforms = enabled_transforms

    # Projections
    self.q_proj = nn.Linear(hidden_dim, hidden_dim)
    self.k_proj = nn.Linear(hidden_dim, hidden_dim)
    self.v_proj = nn.Linear(hidden_dim, hidden_dim)
    self.out_proj = nn.Linear(hidden_dim, hidden_dim)

    # Assign transforms to heads
    self.head_transforms = []
    for i in range(num_heads):
        transform_type = enabled_transforms[i % len(enabled_transforms)]
        self.head_transforms.append(transform_type)

    # Create transform modules
    from ...transforms import FFT1D

    self.fft = FFT1D() if use_fft else None
    self.dct = DCT(normalized=True) if use_dct else None
    self.hadamard = HadamardTransform(normalized=True) if use_hadamard else None

    # Learnable scales per transform type
    self.scales = nn.ParameterDict(
        {t: nn.Parameter(torch.ones(1, 1, self.head_dim)) for t in enabled_transforms}
    )
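The round-robin assignment in the constructor explains the head-count recommendation. This pure-Python sketch mirrors that loop (`assign_head_transforms` is a hypothetical helper, not part of spectrans) and counts how many heads each transform receives.

```python
from collections import Counter

def assign_head_transforms(num_heads, use_fft=True, use_dct=True, use_hadamard=True):
    """Round-robin assignment of transform names to heads."""
    flags = (("fft", use_fft), ("dct", use_dct), ("hadamard", use_hadamard))
    enabled = [name for name, on in flags if on]
    if not enabled:
        raise ValueError("At least one transform type must be enabled")
    return [enabled[i % len(enabled)] for i in range(num_heads)]

counts_9 = Counter(assign_head_transforms(9))                        # 3 heads each
counts_8 = Counter(assign_head_transforms(8))                        # fft 3, dct 3, hadamard 2
counts_8_two = Counter(assign_head_transforms(8, use_hadamard=False))  # fft 4, dct 4
```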
Functions
forward
forward(x: Tensor, _mask: Tensor | None = None, return_attention: bool = False) -> Tensor | tuple[Tensor, ...]

Forward pass of mixed spectral attention.

Parameters:

Name Type Description Default
x Tensor

Input of shape (batch_size, seq_len, hidden_dim).

required
_mask Tensor | None

Ignored; masking is not implemented for mixed spectral attention.

None
return_attention bool

If True, returns (output, None); attention weights are not computed.

False

Returns:

Type Description
Tensor or tuple[Tensor, Tensor]

Output of shape (batch_size, seq_len, hidden_dim), or (output, None) when return_attention=True.

Source code in spectrans/layers/attention/lst.py
def forward(
    self,
    x: Tensor,
    _mask: Tensor | None = None,
    return_attention: bool = False,
) -> Tensor | tuple[Tensor, ...]:
    """Forward pass of mixed spectral attention.

    Parameters
    ----------
    x : Tensor
        Input of shape (batch_size, seq_len, hidden_dim).
    _mask : Tensor | None, default=None
        Attention mask (not implemented for spectral attention).
    return_attention : bool, default=False
        Whether to return attention weights.

    Returns
    -------
    Tensor or tuple[Tensor, Tensor]
        Output and optionally None for weights.
    """
    batch_size, seq_len, _ = x.shape

    # Projections
    Q = self.q_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim)
    K = self.k_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim)
    V = self.v_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim)

    Q = Q.transpose(1, 2)
    K = K.transpose(1, 2)
    V = V.transpose(1, 2)

    # Process each head with its assigned transform
    outputs = []
    for head_idx in range(self.num_heads):
        transform_type = self.head_transforms[head_idx]
        scale = self.scales[transform_type]

        q_head = Q[:, head_idx]
        k_head = K[:, head_idx]
        v_head = V[:, head_idx]

        # Apply appropriate transform
        if transform_type == "fft":
            if self.fft is None:
                raise RuntimeError("FFT transform not initialized")
            q_t = self.fft.transform(q_head, dim=-2)
            k_t = self.fft.transform(k_head, dim=-2)
            v_t = self.fft.transform(v_head, dim=-2)

            # Complex multiplication in frequency domain
            attn_t = q_t * k_t.conj() * scale
            out_t = attn_t * v_t

            # Inverse transform
            out_head = self.fft.inverse_transform(out_t, dim=-2).real

        elif transform_type == "dct":
            if self.dct is None:
                raise RuntimeError("DCT transform not initialized")
            q_t = self.dct.transform(q_head, dim=-2)
            k_t = self.dct.transform(k_head, dim=-2)
            v_t = self.dct.transform(v_head, dim=-2)

            attn_t = q_t * k_t * scale
            out_t = attn_t * v_t

            out_head = self.dct.inverse_transform(out_t, dim=-2)

        else:  # hadamard
            if self.hadamard is None:
                raise RuntimeError("Hadamard transform not initialized")
            q_t = self.hadamard.transform(q_head, dim=-2)
            k_t = self.hadamard.transform(k_head, dim=-2)
            v_t = self.hadamard.transform(v_head, dim=-2)

            attn_t = q_t * k_t * scale
            out_t = attn_t * v_t

            out_head = self.hadamard.inverse_transform(out_t, dim=-2)

        outputs.append(out_head.unsqueeze(1))

    # Concatenate and reshape
    out = torch.cat(outputs, dim=1)
    out = out.transpose(1, 2).contiguous()
    out = out.view(batch_size, seq_len, self.hidden_dim)

    # Output projection
    out = self.out_proj(out)
    out = self.dropout(out)

    output: Tensor = out
    if return_attention:
        return output, None  # type: ignore[return-value]
    return output
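In the FFT branch, multiplying `q_t` by `k_t.conj()` is, up to normalization, the frequency-domain form of circular cross-correlation between the query and key sequences in each feature channel. A one-channel NumPy check, assuming `FFT1D` behaves like a standard DFT along the sequence axis and leaving out the learnable scale:

```python
import numpy as np

rng = np.random.default_rng(0)
n = 12
q, k = rng.standard_normal(n), rng.standard_normal(n)

# Frequency-domain product used by the FFT heads (without the learnable scale).
via_fft = np.fft.ifft(np.fft.fft(q) * np.fft.fft(k).conj()).real

# Direct circular cross-correlation: r[m] = sum_j k[j] * q[(j + m) mod n].
direct = np.array([sum(k[j] * q[(j + m) % n] for j in range(n)) for m in range(n)])
```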

PerformerAttention

PerformerAttention(hidden_dim: int, num_heads: int = 8, num_features: int | None = None, generalized: bool = False, dropout: float = 0.0)

Bases: SpectralAttention

Performer-style attention with FAVOR+ algorithm.

Implements the Performer architecture with positive orthogonal random features (FAVOR+) for softmax kernel approximation.

Parameters:

Name Type Description Default
hidden_dim int

Hidden dimension.

required
num_heads int

Number of attention heads.

8
num_features int | None

Number of random features. If None, uses head_dim.

None
generalized bool

Whether to use generalized attention, which replaces the softmax-kernel approximation with a ReLU random-feature kernel.

False
dropout float

Dropout probability.

0.0

Attributes:

Name Type Description
generalized bool

Whether using generalized attention.

Methods:

Name Description
forward

Forward pass of Performer attention.

Source code in spectrans/layers/attention/spectral.py
def __init__(
    self,
    hidden_dim: int,
    num_heads: int = 8,
    num_features: int | None = None,
    generalized: bool = False,
    dropout: float = 0.0,
):
    super().__init__(
        hidden_dim=hidden_dim,
        num_heads=num_heads,
        num_features=num_features,
        kernel_type="softmax",
        use_orthogonal=True,
        feature_redraw=False,
        dropout=dropout,
    )

    self.generalized = generalized

    if generalized:
        # For generalized attention, use different kernel
        self.kernel = RFFAttentionKernel(
            input_dim=self.head_dim,
            num_features=self.num_features,
            kernel_type="relu",
            use_orthogonal=True,
        )
Functions
forward
forward(x: Tensor, mask: Tensor | None = None, return_attention: bool = False) -> Tensor | tuple[Tensor, ...]

Forward pass of Performer attention.

Parameters:

Name Type Description Default
x Tensor

Input of shape (batch_size, seq_len, hidden_dim).

required
mask Tensor | None

Attention mask.

None
return_attention bool

Whether to return attention weights. Not forwarded when generalized=True.

False

Returns:

Type Description
Tensor or tuple[Tensor, Tensor]

Output tensor and optionally None for weights.

Source code in spectrans/layers/attention/spectral.py
def forward(
    self,
    x: Tensor,
    mask: Tensor | None = None,
    return_attention: bool = False,
) -> Tensor | tuple[Tensor, ...]:
    """Forward pass of Performer attention.

    Parameters
    ----------
    x : Tensor
        Input of shape (batch_size, seq_len, hidden_dim).
    mask : Tensor | None, default=None
        Attention mask.
    return_attention : bool, default=False
        Whether to return attention weights.

    Returns
    -------
    Tensor or tuple[Tensor, Tensor]
        Output tensor and optionally None for weights.
    """
    if self.generalized:
        # Generalized attention without exponential
        return self._generalized_attention(x, mask)
    else:
        # Standard Performer with softmax approximation
        return super().forward(x, mask, return_attention)
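FAVOR+ approximates the softmax kernel \(\exp(x \cdot y)\) with strictly positive random features \(\phi(x) = \exp(w \cdot x - \lVert x \rVert^2 / 2)/\sqrt{m}\), drawing the rows of \(w\) in orthogonal blocks to reduce variance. The NumPy sketch below illustrates that estimator; it is independent of `RFFAttentionKernel`'s actual implementation, and `orthogonal_gaussian` and `positive_features` are hypothetical helpers.

```python
import numpy as np

def orthogonal_gaussian(m, d, rng):
    """Blocks of orthogonal rows, rescaled so row norms match those of Gaussian vectors."""
    blocks = []
    for _ in range(-(-m // d)):                        # ceil(m / d) blocks
        q, _ = np.linalg.qr(rng.standard_normal((d, d)))
        blocks.append(q.T)                             # rows are orthonormal
    w = np.concatenate(blocks)[:m]
    norms = np.linalg.norm(rng.standard_normal((m, d)), axis=1)
    return w * norms[:, None]

def positive_features(x, w):
    """phi(x) = exp(w.x - |x|^2 / 2) / sqrt(m), so E[phi(x).phi(y)] = exp(x.y)."""
    m = w.shape[0]
    return np.exp(x @ w.T - 0.5 * np.sum(x**2, axis=-1, keepdims=True)) / np.sqrt(m)

rng = np.random.default_rng(0)
d, m = 8, 8192
w = orthogonal_gaussian(m, d, rng)
x = rng.standard_normal((5, d)) * 0.2
y = rng.standard_normal((5, d)) * 0.2

approx = positive_features(x, w) @ positive_features(y, w).T
exact = np.exp(x @ y.T)                                # unnormalized softmax kernel
```

Because every feature is positive, the row sums used for normalization in linear attention stay positive, which is the main practical advantage over cosine features.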

SpectralAttention

SpectralAttention(hidden_dim: int, num_heads: int = 8, num_features: int | None = None, head_dim: int | None = None, kernel_type: Literal['gaussian', 'softmax'] = 'softmax', use_orthogonal: bool = True, feature_redraw: bool = False, dropout: float = 0.0, use_bias: bool = True)

Bases: AttentionLayer

Multi-head spectral attention using RFF approximation.

Implements attention using Random Fourier Features to approximate the softmax kernel with linear complexity.
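A rough per-head multiply-accumulate count makes the linear-complexity claim concrete. The sketch ignores exponentials, the softmax itself, and the Q/K/V/output projections, and assumes the default `num_features = head_dim` set in the constructor; the helper names are hypothetical.

```python
def softmax_attention_macs(n, d):
    """Approximate multiply-accumulates for one head of exact attention: Q K^T, then weights @ V."""
    return n * n * d + n * n * d

def rff_attention_macs(n, d, m):
    """Approximate multiply-accumulates for one head of feature-based attention with m features."""
    feature_maps = 2 * n * d * m     # project Q and K onto m random directions
    kv_summary = n * m * d           # phi(K)^T V
    readout = n * m * d              # phi(Q) (phi(K)^T V)
    normaliser = 2 * n * m           # sum of phi(K) over positions, then phi(Q) times that sum
    return feature_maps + kv_summary + readout + normaliser

n, d = 4096, 64
m = d                                # default num_features equals head_dim
ratio = softmax_attention_macs(n, d) / rff_attention_macs(n, d, m)   # 4096 / 129, about 31.75
```

Doubling the sequence length doubles the feature-based cost but quadruples the exact-attention cost, so the gap widens with longer sequences.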

Parameters:

Name Type Description Default
hidden_dim int

Hidden dimension of the model.

required
num_heads int

Number of attention heads.

8
num_features int | None

Number of random features. If None, uses head_dim.

None
head_dim int | None

Dimension per head. If None, uses hidden_dim // num_heads. Either way, head_dim * num_heads must equal hidden_dim.

None
kernel_type Literal['gaussian', 'softmax']

Type of kernel to approximate.

"softmax"
use_orthogonal bool

Whether to use orthogonal random features.

True
feature_redraw bool

Whether to redraw features at each forward pass.

False
dropout float

Dropout probability.

0.0
use_bias bool

Whether to use bias in projections.

True

Attributes:

Name Type Description
head_dim int

Dimension per attention head.

num_features int

Number of random features used.

q_proj Linear

Query projection.

k_proj Linear

Key projection.

v_proj Linear

Value projection.

out_proj Linear

Output projection.

kernel RandomFeatureMap | KernelFunction

Kernel for attention approximation.

Methods:

Name Description
forward

Forward pass of spectral attention.

Source code in spectrans/layers/attention/spectral.py
def __init__(
    self,
    hidden_dim: int,
    num_heads: int = 8,
    num_features: int | None = None,
    head_dim: int | None = None,
    kernel_type: Literal["gaussian", "softmax"] = "softmax",
    use_orthogonal: bool = True,
    feature_redraw: bool = False,
    dropout: float = 0.0,
    use_bias: bool = True,
):
    super().__init__(hidden_dim, num_heads, dropout)

    # Determine head dimension
    self.head_dim = head_dim or (hidden_dim // num_heads)
    assert self.head_dim * num_heads == hidden_dim, (
        f"hidden_dim {hidden_dim} must be divisible by num_heads {num_heads}"
    )

    # Number of random features
    self.num_features = num_features or self.head_dim

    # Projections
    self.q_proj = nn.Linear(hidden_dim, hidden_dim, bias=use_bias)
    self.k_proj = nn.Linear(hidden_dim, hidden_dim, bias=use_bias)
    self.v_proj = nn.Linear(hidden_dim, hidden_dim, bias=use_bias)
    self.out_proj = nn.Linear(hidden_dim, hidden_dim, bias=use_bias)

    # Kernel for approximation
    if kernel_type == "softmax":
        self.kernel = RFFAttentionKernel(
            input_dim=self.head_dim,
            num_features=self.num_features,
            kernel_type="softmax",
            use_orthogonal=use_orthogonal,
            redraw=feature_redraw,
        )
    else:  # gaussian
        self.kernel = GaussianRFFKernel(
            input_dim=self.head_dim,
            num_features=self.num_features,
            sigma=1.0 / math.sqrt(self.head_dim),
            orthogonal=use_orthogonal,
        )

    # Normalization
    self.scale = 1.0 / math.sqrt(self.num_features)
Functions
forward
forward(x: Tensor, mask: Tensor | None = None, return_attention: bool = False) -> Tensor | tuple[Tensor, ...]

Forward pass of spectral attention.

Parameters:

Name Type Description Default
x Tensor

Input tensor of shape (batch_size, seq_len, hidden_dim).

required
mask Tensor | None

Attention mask of shape (batch_size, seq_len).

None
return_attention bool

Whether to return attention weights (not supported).

False

Returns:

Type Description
Tensor or tuple[Tensor, Tensor]

Output tensor of shape (batch_size, seq_len, hidden_dim). If return_attention=True, also returns None (weights not available).

Source code in spectrans/layers/attention/spectral.py
def forward(
    self,
    x: Tensor,
    mask: Tensor | None = None,
    return_attention: bool = False,
) -> Tensor | tuple[Tensor, ...]:
    """Forward pass of spectral attention.

    Parameters
    ----------
    x : Tensor
        Input tensor of shape (batch_size, seq_len, hidden_dim).
    mask : Tensor | None, default=None
        Attention mask of shape (batch_size, seq_len).
    return_attention : bool, default=False
        Whether to return attention weights (not supported).

    Returns
    -------
    Tensor or tuple[Tensor, Tensor]
        Output tensor of shape (batch_size, seq_len, hidden_dim).
        If return_attention=True, also returns None (weights not available).
    """
    batch_size, seq_len, _ = x.shape

    # Linear projections
    Q = self.q_proj(x)
    K = self.k_proj(x)
    V = self.v_proj(x)

    # Reshape for multi-head attention
    Q = Q.view(batch_size, seq_len, self.num_heads, self.head_dim)
    K = K.view(batch_size, seq_len, self.num_heads, self.head_dim)
    V = V.view(batch_size, seq_len, self.num_heads, self.head_dim)

    # Transpose to (batch, heads, seq_len, head_dim)
    Q = Q.transpose(1, 2)
    K = K.transpose(1, 2)
    V = V.transpose(1, 2)

    # Apply kernel feature maps
    Q_features = self.kernel(Q)  # (batch, heads, seq_len, num_features)
    K_features = self.kernel(K)  # (batch, heads, seq_len, num_features)

    # Apply mask if provided
    if mask is not None:
        # Expand mask for heads dimension
        mask = mask.unsqueeze(1).unsqueeze(-1)  # (batch, 1, seq_len, 1)
        K_features = K_features.masked_fill(~mask, 0)
        V = V.masked_fill(~mask, 0)

    # Compute KV (batch, heads, num_features, head_dim)
    KV = torch.matmul(K_features.transpose(-2, -1), V)

    # Compute QKV (batch, heads, seq_len, head_dim)
    out: Tensor = torch.matmul(Q_features, KV)

    # Normalize
    # Compute normalizer: Q_features @ (K_features^T @ 1)
    K_sum = K_features.sum(dim=-2, keepdim=True).transpose(-2, -1)
    normalizer = torch.matmul(Q_features, K_sum) + 1e-6
    out = out / normalizer

    # Transpose back and reshape
    out = out.transpose(1, 2).contiguous()
    out = out.view(batch_size, seq_len, self.hidden_dim)

    # Output projection and dropout
    out = self.out_proj(out)
    out = self.dropout(out)

    if return_attention:
        # Attention weights not directly available in linear attention
        return out, None  # type: ignore[return-value]
    return out
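The forward pass above never forms the seq_len × seq_len score matrix. It computes K_features^T V first and normalizes by Q_features (K_features^T 1). The single-head NumPy sketch below (function names hypothetical) checks that this reordering, including the zeroing of masked keys and values, gives the same result as the explicit quadratic form built from the same features. It assumes a boolean mask in which True marks positions to keep, matching the `~mask` fill in the source.

```python
import numpy as np


def linear_attention(q_feat, k_feat, v, mask=None, eps=1e-6):
    """Kernelized attention in O(n * m * d), mirroring the order of operations above.

    q_feat, k_feat: (n, m) feature-mapped queries/keys; v: (n, d) values;
    mask: optional boolean (n,) array, True for positions that may be attended to.
    """
    if mask is not None:
        k_feat = np.where(mask[:, None], k_feat, 0.0)  # drop masked keys
        v = np.where(mask[:, None], v, 0.0)            # and their values
    kv = k_feat.T @ v                                  # (m, d), no (n, n) matrix
    normalizer = q_feat @ k_feat.sum(axis=0)[:, None] + eps  # (n, 1)
    return (q_feat @ kv) / normalizer


def quadratic_attention(q_feat, k_feat, v, mask=None, eps=1e-6):
    """Reference version that materializes the (n, n) score matrix."""
    scores = q_feat @ k_feat.T
    if mask is not None:
        scores = scores * mask[None, :]  # zero out masked key columns
    return (scores @ v) / (scores.sum(axis=1, keepdims=True) + eps)


rng = np.random.default_rng(1)
n, m, d = 6, 8, 3
q_feat = rng.random((n, m))  # positive features, e.g. from a random feature map
k_feat = rng.random((n, m))
v = rng.standard_normal((n, d))
mask = np.array([True, True, True, True, False, False])

fast = linear_attention(q_feat, k_feat, v, mask)
slow = quadratic_attention(q_feat, k_feat, v, mask)
print(np.allclose(fast, slow))
```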

AdaptiveGlobalFilter

AdaptiveGlobalFilter(hidden_dim: int, sequence_length: int, activation: ActivationType = 'sigmoid', dropout: float = 0.0, norm_eps: float = 1e-05, learnable_filters: bool = True, fft_norm: FFTNorm = 'ortho', filter_init_std: float = 0.02, filter_regularization: float = 0.0, adaptive_initialization: bool = True, spectral_dropout_p: float = 0.0)

Bases: FilterMixingLayer

Adaptive Global Filter with regularization and frequency-aware initialization.

Extends global filtering with frequency-aware filter initialization (a larger initial scale at low frequencies), optional L2 regularization of the filter parameters, and optional spectral dropout applied to the filter.
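The frequency-aware initialization in the `__init__` source scales the random filter values by a weight that is largest at the DC bin and decays with |frequency|. A NumPy sketch of that computation follows; `numpy.random` stands in for `torch.randn`, and the function name is hypothetical.

```python
import numpy as np


def frequency_aware_init(sequence_length, hidden_dim, init_std=0.02, seed=0):
    """Sketch of the adaptive filter initialization (hypothetical helper)."""
    freqs = np.fft.fftfreq(sequence_length)  # same bin order as the FFT output
    weights = 1.0 / (np.abs(freqs) + 0.1)    # +0.1 avoids division by zero at DC
    weights = weights / weights.mean()       # normalize to mean 1
    rng = np.random.default_rng(seed)
    real = rng.standard_normal((sequence_length, hidden_dim)) * init_std * weights[:, None]
    imag = rng.standard_normal((sequence_length, hidden_dim)) * init_std * weights[:, None]
    return weights, real, imag


weights, real, imag = frequency_aware_init(sequence_length=8, hidden_dim=4)
print(np.round(weights, 3))
```

For sequence_length=8, `np.fft.fftfreq` orders the bins as [0, 1/8, 1/4, 3/8, -1/2, -3/8, -1/4, -1/8], so the largest weight sits at index 0 and the smallest at the Nyquist bin, index 4.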

Parameters:

Name Type Description Default
hidden_dim int

Hidden dimension of input tensors.

required
sequence_length int

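Sequence length used to size the filters. Inputs of a different length are handled by linearly interpolating the filters to the input length.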

required
activation ActivationType

Filter activation function.

"sigmoid"
dropout float

Dropout probability.

0.0
norm_eps float

Numerical stability epsilon.

1e-5
learnable_filters bool

Whether filters are learnable.

True
fft_norm FFTNorm

FFT normalization mode ("forward", "backward", or "ortho").

"ortho"
filter_init_std float

Filter initialization standard deviation.

0.02
filter_regularization float

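L2 regularization strength for filter parameters. The penalty is not applied in forward(); add the value returned by get_regularization_loss() to the training loss.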

0.0
adaptive_initialization bool

Whether to use frequency-aware initialization.

True
spectral_dropout_p float

Dropout probability applied to the complex filter (not the input) during training.

0.0

Attributes:

Name Type Description
filter_regularization float

Regularization strength.

adaptive_initialization bool

Whether adaptive initialization is used.

spectral_dropout_p float

Spectral dropout probability.

spectral_dropout Module

Spectral dropout layer.

Methods:

Name Description
forward

Apply adaptive global filtering.

get_filter_response

Get adaptive frequency response.

get_regularization_loss

Compute L2 regularization loss for filter parameters.

get_spectral_properties

Get adaptive filter properties.

Source code in spectrans/layers/mixing/global_filter.py
def __init__(
    self,
    hidden_dim: int,
    sequence_length: int,
    activation: ActivationType = "sigmoid",
    dropout: float = 0.0,
    norm_eps: float = 1e-5,
    learnable_filters: bool = True,
    fft_norm: FFTNorm = "ortho",
    filter_init_std: float = 0.02,
    filter_regularization: float = 0.0,
    adaptive_initialization: bool = True,
    spectral_dropout_p: float = 0.0,
):
    super().__init__(hidden_dim, sequence_length, dropout, norm_eps, learnable_filters)
    self.activation = activation
    self.filter_regularization = filter_regularization
    self.adaptive_initialization = adaptive_initialization
    self.spectral_dropout_p = spectral_dropout_p

    # Initialize filter parameters
    if adaptive_initialization:
        # Frequency-aware initialization: smaller values for high frequencies
        frequencies = torch.fft.fftfreq(sequence_length)
        # Weight by inverse |frequency|; the +0.1 offset avoids division by zero at DC
        freq_weights = 1.0 / (torch.abs(frequencies) + 0.1)
        freq_weights = freq_weights / freq_weights.mean()

        self.filter_real = nn.Parameter(
            torch.randn(sequence_length, hidden_dim)
            * filter_init_std
            * freq_weights.unsqueeze(-1)
        )
        self.filter_imag = nn.Parameter(
            torch.randn(sequence_length, hidden_dim)
            * filter_init_std
            * freq_weights.unsqueeze(-1)
        )
    else:
        # Standard initialization
        self.filter_real = nn.Parameter(
            torch.randn(sequence_length, hidden_dim) * filter_init_std
        )
        self.filter_imag = nn.Parameter(
            torch.randn(sequence_length, hidden_dim) * filter_init_std
        )

    # Store FFT transform as non-module attribute
    self.fft1d: FFT1D  # Type annotation for mypy
    object.__setattr__(self, "fft1d", FFT1D(norm=fft_norm))

    self.activation_fn: Callable[[Tensor], Tensor]
    if activation == "sigmoid":
        self.activation_fn = nn.Sigmoid()
    elif activation == "tanh":
        self.activation_fn = nn.Tanh()
    elif activation == "identity":
        self.activation_fn = nn.Identity()
    else:
        raise ValueError(f"Unknown activation: {activation}")

    # Spectral dropout for regularization
    if spectral_dropout_p > 0:
        self.spectral_dropout: nn.Module = nn.Dropout2d(spectral_dropout_p)
    else:
        self.spectral_dropout = nn.Identity()
Functions
forward
forward(x: Tensor) -> Tensor

Apply adaptive global filtering.

Parameters:

Name Type Description Default
x Tensor

Input tensor of shape (batch_size, sequence_length, hidden_dim).

required

Returns:

Type Description
Tensor

Filtered real-valued tensor with the same shape as the input.

Source code in spectrans/layers/mixing/global_filter.py
def forward(self, x: Tensor) -> Tensor:
    """Apply adaptive global filtering.

    Parameters
    ----------
    x : Tensor
        Input tensor of shape (batch_size, sequence_length, hidden_dim).

    Returns
    -------
    Tensor
        Adaptively filtered tensor.
    """
    # Transform to frequency domain
    x_freq = self.fft1d.transform(x, dim=1)

    # Get actual sequence length
    seq_len = x_freq.shape[1]

    # Adapt filter to actual sequence length using interpolation
    if seq_len != self.sequence_length:
        # Use interpolation to adapt filters to the actual sequence length
        filter_real = (
            nn.functional.interpolate(
                self.filter_real.T.unsqueeze(0),  # (1, hidden_dim, sequence_length)
                size=seq_len,
                mode="linear",
                align_corners=False,
            )
            .squeeze(0)
            .T
        )  # (seq_len, hidden_dim)

        filter_imag = (
            nn.functional.interpolate(
                self.filter_imag.T.unsqueeze(0),  # (1, hidden_dim, sequence_length)
                size=seq_len,
                mode="linear",
                align_corners=False,
            )
            .squeeze(0)
            .T
        )  # (seq_len, hidden_dim)
    else:
        filter_real = self.filter_real
        filter_imag = self.filter_imag

    # Create complex filter with activation
    filter_complex = make_complex(
        self.activation_fn(filter_real), self.activation_fn(filter_imag)
    )

    # Apply spectral dropout to filter (not input)
    if self.training and self.spectral_dropout_p > 0:
        filter_complex = self.spectral_dropout(filter_complex)

    # Apply filtering in frequency domain
    filtered_freq = complex_multiply(x_freq, filter_complex)

    # Transform back to time domain
    filtered_time = self.fft1d.inverse_transform(filtered_freq, dim=1)

    # Take real part
    output = torch.real(filtered_time)

    # Apply standard dropout
    output = self.dropout(output)

    return output  # type: ignore[no-any-return]
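A NumPy sketch of the filtering path in `forward` follows. It runs in evaluation mode, so both dropouts are skipped. It assumes `make_complex(re, im)` builds `re + 1j * im`, that `complex_multiply` is elementwise, and that `FFT1D(norm="ortho")` matches `np.fft.fft(..., norm="ortho")`. The resize helper reimplements linear interpolation with `align_corners=False` by hand. All function names are hypothetical.

```python
import numpy as np


def resize_linear(filt: np.ndarray, new_len: int) -> np.ndarray:
    """Resize along axis 0 like interpolate(mode="linear", align_corners=False)."""
    old_len = filt.shape[0]
    if new_len == old_len:
        return filt
    scale = old_len / new_len
    # Source coordinate of each output sample, clamped at 0 as PyTorch does
    src = np.maximum((np.arange(new_len) + 0.5) * scale - 0.5, 0.0)
    i0 = np.minimum(np.floor(src).astype(int), old_len - 1)
    i1 = np.minimum(i0 + 1, old_len - 1)
    frac = (src - i0)[:, None]
    return (1.0 - frac) * filt[i0] + frac * filt[i1]


def sigmoid(t: np.ndarray) -> np.ndarray:
    return 1.0 / (1.0 + np.exp(-t))


def global_filter(x, filter_real, filter_imag, activation=sigmoid):
    """x: (batch, seq_len, hidden_dim); filters: (sequence_length, hidden_dim)."""
    seq_len = x.shape[1]
    x_freq = np.fft.fft(x, axis=1, norm="ortho")          # to frequency domain
    fr = resize_linear(filter_real, seq_len)               # adapt filter length
    fi = resize_linear(filter_imag, seq_len)
    filt = activation(fr) + 1j * activation(fi)            # assumed make_complex
    y = np.fft.ifft(x_freq * filt[None], axis=1, norm="ortho")  # filter, invert
    return np.real(y)                                      # keep the real part


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 5, 3))
ones, zeros = np.ones((8, 3)), np.zeros((8, 3))
y = global_filter(x, ones, zeros, activation=lambda t: t)
print(np.allclose(y, x))
```

With an identity activation, a filter of ones (real) and zeros (imaginary) passes the input through unchanged. This holds even when the stored filter length (8) differs from the input length (5), because interpolating a constant filter leaves it constant.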
get_filter_response
get_filter_response() -> Tensor

Get adaptive frequency response.

Returns:

Type Description
Tensor

Complex frequency response with current parameters.

Source code in spectrans/layers/mixing/global_filter.py
def get_filter_response(self) -> Tensor:
    """Get adaptive frequency response.

    Returns
    -------
    Tensor
        Complex frequency response with current parameters.
    """
    return make_complex(
        self.activation_fn(self.filter_real), self.activation_fn(self.filter_imag)
    )
get_regularization_loss
get_regularization_loss() -> Tensor

Compute L2 regularization loss for filter parameters.

Returns:

Type Description
Tensor

Scalar regularization loss.

Source code in spectrans/layers/mixing/global_filter.py
def get_regularization_loss(self) -> Tensor:
    """Compute L2 regularization loss for filter parameters.

    Returns
    -------
    Tensor
        Scalar regularization loss.
    """
    if self.filter_regularization <= 0:
        return torch.tensor(0.0, device=self.filter_real.device)

    real_loss = torch.norm(self.filter_real, p=2) ** 2
    imag_loss = torch.norm(self.filter_imag, p=2) ** 2

    return self.filter_regularization * (real_loss + imag_loss)  # type: ignore[no-any-return]
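Because `forward` does not add the filter penalty, a training loop has to include it explicitly. With the layer above, that would be `loss = task_loss + layer.get_regularization_loss()`. The NumPy sketch below reproduces the arithmetic, lambda * (||filter_real||_F^2 + ||filter_imag||_F^2), on small made-up arrays. `filter_l2_penalty` is a hypothetical helper.

```python
import numpy as np


def filter_l2_penalty(filter_real, filter_imag, strength):
    """strength * (sum of squared real parts + sum of squared imaginary parts)."""
    if strength <= 0:
        return 0.0
    return strength * (np.sum(filter_real**2) + np.sum(filter_imag**2))


real = np.array([[1.0, 2.0]])
imag = np.array([[3.0, 0.0]])
task_loss = 1.25  # stand-in for the model's task loss
total_loss = task_loss + filter_l2_penalty(real, imag, strength=0.5)
print(total_loss)
```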
get_spectral_properties
get_spectral_properties() -> dict[str, str | bool | int]

Get adaptive filter properties.

Returns:

Type Description
dict[str, str | bool | int]

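Dictionary of filter properties. It includes flags for adaptive initialization, regularization (filter_regularization > 0), and spectral dropout (spectral_dropout_p > 0), plus the activation name and the filter parameter count (2 * sequence_length * hidden_dim).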

Source code in spectrans/layers/mixing/global_filter.py
def get_spectral_properties(self) -> dict[str, str | bool | int]:
    """Get adaptive filter properties.

    Returns
    -------
    dict[str, str | bool | int]
        Comprehensive properties including adaptive features.
    """
    return {
        "frequency_domain": True,
        "learnable_filters": True,
        "complex_valued": True,
        "selective_filtering": True,
        "energy_preserving": False,
        "adaptive_initialization": self.adaptive_initialization,
        "regularization": self.filter_regularization > 0,
        "spectral_dropout": self.spectral_dropout_p > 0,
        "activation": self.activation,
        "parameter_count": 2 * self.sequence_length * self.hidden_dim,
    }

AFNOMixing

AFNOMixing(hidden_dim: int, max_sequence_length: int, modes_seq: int | None = None, modes_hidden: int | None = None, mlp_ratio: float = 2.0, activation: ActivationType = 'gelu', dropout: float = 0.0)

Bases: MixingLayer

Adaptive Fourier Neural Operator mixing layer.

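This layer mixes tokens by applying learnable transformations to a truncated set of Fourier modes. The input is layer-normalized, zero-padded to max_sequence_length, and transformed with a 2D real FFT over the sequence and hidden dimensions. Only the first modes_seq coefficients along the sequence axis and the first modes_hidden along the hidden axis are kept. These are multiplied by learnable complex weights and passed through an MLP, then zero-padded back to full size, inverse-transformed, and added to the input as a residual. Restricting the learned transformation to the retained modes reduces its cost compared with operating on the full spectrum.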

Parameters:

Name Type Description Default
hidden_dim int

Hidden dimension of the input/output tensors.

required
max_sequence_length int

Maximum sequence length the model will process.

required
modes_seq int | None

Number of Fourier modes to keep in the sequence dimension. If None, defaults to max_sequence_length // 2. Values are capped at max_sequence_length.

None
modes_hidden int | None

Number of Fourier modes to keep in the hidden dimension. If None, defaults to hidden_dim // 2. Values are capped at hidden_dim // 2 + 1, the length of the real-FFT hidden axis.

None
mlp_ratio float

Expansion ratio for the MLP in Fourier domain. Default is 2.0.

2.0
activation ActivationType

Activation function for the MLP: one of 'gelu', 'relu', 'silu', 'tanh', 'sigmoid', or 'identity'. Default is 'gelu'.

'gelu'
dropout float

Dropout probability for MLP. Default is 0.0.

0.0

Attributes:

Name Type Description
hidden_dim int

Hidden dimension size.

max_sequence_length int

Maximum supported sequence length.

modes_seq int

Number of retained Fourier modes in sequence dimension.

modes_hidden int

Number of retained Fourier modes in hidden dimension.

mlp_ratio float

MLP expansion ratio.

fourier_weight Parameter

Complex-valued learnable weights for Fourier modes.

mlp Sequential

MLP applied in Fourier domain.

Examples:

>>> import torch
>>> layer = AFNOMixing(hidden_dim=768, max_sequence_length=512, modes_seq=128)
>>> x = torch.randn(32, 512, 768)
>>> output = layer(x)
>>> print(output.shape)
torch.Size([32, 512, 768])

Methods:

Name Description
forward

Apply AFNO mixing to input tensor.

get_spectral_properties

Get mathematical properties of AFNO operation.

from_config

Create AFNOMixing layer from configuration.

Source code in spectrans/layers/mixing/afno.py
def __init__(
    self,
    hidden_dim: int,
    max_sequence_length: int,
    modes_seq: int | None = None,
    modes_hidden: int | None = None,
    mlp_ratio: float = 2.0,
    activation: ActivationType = "gelu",
    dropout: float = 0.0,
):
    super().__init__(hidden_dim=hidden_dim, dropout=dropout)

    self.max_sequence_length = max_sequence_length

    # Set default mode truncation if not specified
    self.modes_seq = modes_seq if modes_seq is not None else max_sequence_length // 2
    self.modes_hidden = modes_hidden if modes_hidden is not None else hidden_dim // 2

    # Ensure modes don't exceed actual dimensions (for rfft)
    # For rfft, the last dimension has size n//2 + 1
    self.modes_seq = min(self.modes_seq, max_sequence_length)
    self.modes_hidden = min(self.modes_hidden, hidden_dim // 2 + 1)

    self.mlp_ratio = mlp_ratio

    # Complex-valued learnable weights for Fourier modes
    # We use real FFT, so last dimension is reduced
    scale = 1 / (self.modes_seq * self.modes_hidden)
    self.fourier_weight = nn.Parameter(
        torch.randn(self.modes_seq, self.modes_hidden, 2) * scale
    )

    # MLP in Fourier domain
    mlp_hidden_dim = int(hidden_dim * mlp_ratio)

    # Activation function
    activation_fn: nn.Module
    if activation == "gelu":
        activation_fn = nn.GELU()
    elif activation == "relu":
        activation_fn = nn.ReLU()
    elif activation == "silu":
        activation_fn = nn.SiLU()
    elif activation == "tanh":
        activation_fn = nn.Tanh()
    elif activation == "sigmoid":
        activation_fn = nn.Sigmoid()
    elif activation == "identity":
        activation_fn = nn.Identity()
    else:
        raise ValueError(f"Unsupported activation: {activation}")

    # MLP operates on real and imaginary parts concatenated
    self.mlp = nn.Sequential(
        nn.Linear(self.modes_seq * self.modes_hidden * 2, mlp_hidden_dim),
        activation_fn,
        nn.Dropout(dropout),
        nn.Linear(mlp_hidden_dim, self.modes_seq * self.modes_hidden * 2),
        nn.Dropout(dropout),
    )

    # Layer normalization
    self.norm = nn.LayerNorm(hidden_dim)

    self._init_weights()
Functions
forward
forward(x: Tensor) -> Tensor

Apply AFNO mixing to input tensor.

Parameters:

Name Type Description Default
x Tensor

Input tensor of shape (batch_size, sequence_length, hidden_dim).

required

Returns:

Type Description
Tensor

Output tensor with the same shape and dtype as the input (the input plus the AFNO update, via a residual connection).

Source code in spectrans/layers/mixing/afno.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Apply AFNO mixing to input tensor.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor of shape (batch_size, sequence_length, hidden_dim).

    Returns
    -------
    torch.Tensor
        Output tensor of same shape as input.
    """
    batch_size, seq_len, hidden_dim = x.shape
    residual = x
    input_dtype = x.dtype

    # Convert to float32 for processing if needed
    if x.dtype != torch.float32:
        x = x.to(torch.float32)
        residual = residual.to(torch.float32)

    # Apply layer norm
    x = self.norm(x)

    # Pad if necessary to match max_sequence_length
    if seq_len < self.max_sequence_length:
        padding = self.max_sequence_length - seq_len
        x = F.pad(x, (0, 0, 0, padding))

    # Step 1: Transform to Fourier space using 2D FFT
    # Treat (sequence, hidden) as 2D spatial dimensions
    # Use safe wrapper to handle MKL issues
    x_ft = safe_rfft2(x, dim=(1, 2), norm="ortho")

    # Step 2: Mode truncation - keep only low-frequency modes
    x_ft_truncated = x_ft[:, : self.modes_seq, : self.modes_hidden]

    # Step 3: Apply learnable transformation in Fourier domain
    # First apply pointwise multiplication with learnable weights
    weight_complex = torch.view_as_complex(self.fourier_weight)
    x_ft_truncated = x_ft_truncated * weight_complex

    # Flatten for MLP processing
    batch_size_ft = x_ft_truncated.shape[0]
    x_ft_flat = torch.view_as_real(x_ft_truncated).reshape(batch_size_ft, -1)

    # Apply MLP
    x_ft_flat = self.mlp(x_ft_flat)

    # Reshape back to complex truncated form
    x_ft_truncated = x_ft_flat.reshape(batch_size_ft, self.modes_seq, self.modes_hidden, 2)
    x_ft_truncated = torch.view_as_complex(x_ft_truncated)

    # Step 4: Zero-pad back to original size
    x_ft_padded = torch.zeros(
        (batch_size, self.max_sequence_length, hidden_dim // 2 + 1),
        dtype=x_ft.dtype,
        device=x_ft.device,
    )
    x_ft_padded[:, : self.modes_seq, : self.modes_hidden] = x_ft_truncated

    # Step 5: Inverse FFT to get back to spatial domain
    # Use safe wrapper to handle MKL issues
    x_spatial = safe_irfft2(
        x_ft_padded, s=(self.max_sequence_length, hidden_dim), dim=(1, 2), norm="ortho"
    )

    # Remove padding if it was added
    if seq_len < self.max_sequence_length:
        x_spatial = x_spatial[:, :seq_len, :]

    # Step 6: Add residual connection
    output = residual + x_spatial

    # Convert back to original dtype if needed
    if output.dtype != input_dtype:
        output = output.to(input_dtype)

    return output
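The steps in `forward` can be traced with a NumPy sketch. It assumes `safe_rfft2`/`safe_irfft2` compute the same transforms as `np.fft.rfft2`/`np.fft.irfft2` with `norm="ortho"`. It also uses a layer norm without learned scale and shift, omits the float32 cast, and accepts an optional callable in place of the trained MLP. All names are hypothetical. The sketch raises on inputs longer than `max_len`: in the source above such an input is not truncated, so the residual addition would appear to receive tensors of different lengths.

```python
import numpy as np


def layer_norm(x, eps=1e-5):
    """LayerNorm over the last axis without affine parameters."""
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)


def afno_forward_sketch(x, weight, modes_seq, modes_hidden, max_len, mlp=None):
    """x: (batch, n, d) with n <= max_len; weight: (modes_seq, modes_hidden)."""
    batch, n, d = x.shape
    if n > max_len:
        raise ValueError("sequence longer than max_sequence_length")
    h = np.pad(layer_norm(x), ((0, 0), (0, max_len - n), (0, 0)))  # pad sequence
    x_ft = np.fft.rfft2(h, axes=(1, 2), norm="ortho")   # (batch, max_len, d//2 + 1)
    z = x_ft[:, :modes_seq, :modes_hidden] * weight      # truncate, pointwise weight
    if mlp is not None:                                  # MLP on flattened (re, im)
        flat = np.stack([z.real, z.imag], axis=-1).reshape(batch, -1)
        flat = mlp(flat).reshape(batch, modes_seq, modes_hidden, 2)
        z = flat[..., 0] + 1j * flat[..., 1]
    padded = np.zeros_like(x_ft)
    padded[:, :modes_seq, :modes_hidden] = z             # zero-fill dropped modes
    y = np.fft.irfft2(padded, s=(max_len, d), axes=(1, 2), norm="ortho")
    return x + y[:, :n, :]                               # residual on un-normalized x


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 5, 6))       # (batch, seq_len=5, hidden_dim=6)
weight = np.ones((3, 2), dtype=complex)  # modes_seq=3, modes_hidden=2
out = afno_forward_sketch(x, weight, modes_seq=3, modes_hidden=2, max_len=8,
                          mlp=lambda f: f)
print(out.shape)
```

With every mode kept (modes_seq = max_len, modes_hidden = hidden_dim // 2 + 1), unit weights, and no MLP, the inverse transform recovers the normalized input, so the output reduces to x + layer_norm(x).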
get_spectral_properties
get_spectral_properties() -> dict[str, bool]

Get mathematical properties of AFNO operation.

Returns:

Type Description
dict[str, bool]

Mathematical properties of the transform.

Source code in spectrans/layers/mixing/afno.py
def get_spectral_properties(self) -> dict[str, bool]:
    """Get mathematical properties of AFNO operation.

    Returns
    -------
    dict[str, bool]
        Mathematical properties of the transform.
    """
    return {
        "unitary": False,  # Not unitary due to mode truncation and MLP
        "real_output": True,  # Output is real-valued
        "frequency_domain": True,  # Operations in Fourier domain
        "energy_preserving": False,  # Energy not preserved due to truncation
        "learnable_parameters": True,  # Has learnable weights and MLP
        "translation_equivariant": False,  # Not equivariant due to MLP
        "mode_truncation": True,  # Uses Fourier mode truncation
        "adaptive": True,  # Adaptive filtering based on learned parameters
    }
from_config classmethod
from_config(config: AFNOMixingConfig) -> AFNOMixing

Create AFNOMixing layer from configuration.

Parameters:

Name Type Description Default
config AFNOMixingConfig

Configuration object with layer parameters.

required

Returns:

Type Description
AFNOMixing

Configured AFNO mixing layer.

Source code in spectrans/layers/mixing/afno.py
@classmethod
def from_config(cls, config: "AFNOMixingConfig") -> "AFNOMixing":
    """Create AFNOMixing layer from configuration.

    Parameters
    ----------
    config : AFNOMixingConfig
        Configuration object with layer parameters.

    Returns
    -------
    AFNOMixing
        Configured AFNO mixing layer.
    """
    return cls(
        hidden_dim=config.hidden_dim,
        max_sequence_length=config.max_sequence_length,
        modes_seq=config.modes_seq,
        modes_hidden=config.modes_hidden,
        mlp_ratio=config.mlp_ratio,
        activation=config.activation,
        dropout=config.dropout,
    )

FilterMixingLayer

FilterMixingLayer(hidden_dim: int, sequence_length: int, dropout: float = 0.0, norm_eps: float = 1e-05, learnable_filters: bool = True)

Bases: MixingLayer

Base class for frequency-domain filtering operations.

Filter mixing layers apply filters in the frequency domain to selectively emphasize or suppress frequency components of the input sequence. Subclasses must implement get_filter_response().

Parameters:

Name Type Description Default
hidden_dim int

Hidden dimension of the model.

required
sequence_length int

Expected sequence length for filter initialization.

required
dropout float

Dropout probability for regularization.

0.0
norm_eps float

Epsilon for numerical stability.

1e-5
learnable_filters bool

Whether filters are learnable parameters.

True

Attributes:

Name Type Description
sequence_length int

Expected sequence length.

learnable_filters bool

Whether filters are learnable.

Methods:

Name Description
get_spectral_properties

Get properties specific to filtering operations.

get_filter_response

Get the frequency response of the current filters.

analyze_frequency_response

Analyze the frequency response characteristics.

Source code in spectrans/layers/mixing/base.py
def __init__(
    self,
    hidden_dim: int,
    sequence_length: int,
    dropout: float = 0.0,
    norm_eps: float = 1e-5,
    learnable_filters: bool = True,
):
    super().__init__(hidden_dim, dropout, norm_eps)
    self.sequence_length = sequence_length
    self.learnable_filters = learnable_filters
Functions
get_spectral_properties
get_spectral_properties() -> dict[str, Any]

Get properties specific to filtering operations.

Returns:

Type Description
dict[str, Any]

Dictionary containing filter-specific properties.

Source code in spectrans/layers/mixing/base.py
def get_spectral_properties(self) -> dict[str, Any]:
    """Get properties specific to filtering operations.

    Returns
    -------
    dict[str, Any]
        Dictionary containing filter-specific properties.
    """
    return {
        "frequency_domain": True,
        "learnable_filters": self.learnable_filters,
        "selective_filtering": True,
        "complex_valued": True,
        "energy_preserving": False,  # Filtering can change energy
    }
get_filter_response abstractmethod
get_filter_response() -> Tensor

Get the frequency response of the current filters.

Returns:

Type Description
Tensor

Complex-valued frequency response of shape matching the filter parameters.

Source code in spectrans/layers/mixing/base.py
@abstractmethod
def get_filter_response(self) -> torch.Tensor:
    """Get the frequency response of the current filters.

    Returns
    -------
    torch.Tensor
        Complex-valued frequency response of shape matching the filter parameters.
    """
    pass
analyze_frequency_response
analyze_frequency_response() -> dict[str, Tensor]

Analyze the frequency response characteristics.

Returns:

Type Description
dict[str, Tensor]

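Dictionary containing:
- 'magnitude': magnitude response
- 'phase': phase response
- 'group_delay': group delay, the negative finite difference of the phase
- 'total_energy': total energy along the last dimension
- 'low_freq_energy': fraction of energy in the first quarter of the last dimension
- 'high_freq_energy': fraction of energy in the last quarter of the last dimension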

Source code in spectrans/layers/mixing/base.py
def analyze_frequency_response(self) -> dict[str, torch.Tensor]:
    """Analyze the frequency response characteristics.

    Returns
    -------
    dict[str, torch.Tensor]
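        Dictionary containing:
        - 'magnitude': Magnitude response
        - 'phase': Phase response
        - 'group_delay': Group delay (negative finite difference of phase)
        - 'total_energy': Total energy along the last dimension
        - 'low_freq_energy': Fraction of energy in the first quarter of the last dimension
        - 'high_freq_energy': Fraction of energy in the last quarter of the last dimension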
    """
    response = self.get_filter_response()

    magnitude = torch.abs(response)
    phase = torch.angle(response)

    # Compute group delay (negative derivative of phase)
    # For discrete signals, use finite differences
    phase_diff = torch.diff(phase, dim=-1)
    group_delay = -phase_diff

    # Analyze energy in different frequency bands
    total_energy = torch.sum(magnitude**2, dim=-1, keepdim=True)
    low_freq_energy = torch.sum(
        magnitude[..., : magnitude.size(-1) // 4] ** 2, dim=-1, keepdim=True
    )
    high_freq_energy = torch.sum(
        magnitude[..., 3 * magnitude.size(-1) // 4 :] ** 2, dim=-1, keepdim=True
    )

    return {
        "magnitude": magnitude,
        "phase": phase,
        "group_delay": group_delay,
        "total_energy": total_energy,
        "low_freq_energy": low_freq_energy / (total_energy + self.norm_eps),
        "high_freq_energy": high_freq_energy / (total_energy + self.norm_eps),
    }
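A NumPy transcription of `analyze_frequency_response` makes the band definitions concrete. The test response has unit magnitude and linear phase -0.1k, so the group delay is 0.1 everywhere and the lowest and highest quarters each hold about a quarter of the energy. As written, the differences and bands are taken along the last dimension. For a filter stored as (sequence_length, hidden_dim), such as `AdaptiveGlobalFilter`, that is the hidden dimension, so a transpose may be needed to analyze the frequency axis. The helper name is hypothetical.

```python
import numpy as np


def analyze_response(response, eps=1e-5):
    """NumPy version of analyze_frequency_response (bands along the last axis)."""
    magnitude = np.abs(response)
    phase = np.angle(response)
    group_delay = -np.diff(phase, axis=-1)  # finite difference, no phase unwrapping
    n = magnitude.shape[-1]
    total = np.sum(magnitude**2, axis=-1, keepdims=True)
    low = np.sum(magnitude[..., : n // 4] ** 2, axis=-1, keepdims=True)
    high = np.sum(magnitude[..., 3 * n // 4 :] ** 2, axis=-1, keepdims=True)
    return {
        "magnitude": magnitude,
        "phase": phase,
        "group_delay": group_delay,
        "total_energy": total,
        "low_freq_energy": low / (total + eps),
        "high_freq_energy": high / (total + eps),
    }


# Unit-magnitude response with linear phase -0.1 * k
k = np.arange(8)
props = analyze_response(np.exp(-0.1j * k))
print(props["group_delay"], props["low_freq_energy"], props["high_freq_energy"])
```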

FourierMixing

FourierMixing(hidden_dim: int, dropout: float = 0.0, norm_eps: float = 1e-05, energy_tolerance: float = 0.0001, fft_norm: FFTNorm = 'ortho', keep_complex: bool = False)

Bases: UnitaryMixingLayer

FNet-style Fourier mixing layer.

Implements the core FNet mixing operation using 2D Fourier transforms along both sequence and feature dimensions, providing an alternative to attention with \(O(n \log n)\) complexity.

The operation performs:
1. A 2D FFT across the sequence and feature dimensions.
2. Either extraction of the real part for the output (original FNet behavior, the default) or, with keep_complex=True, retention of the complex values to preserve full information.
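The NumPy sketch below illustrates the two steps, assuming `FFT2D(norm="ortho")` matches `np.fft.fft2(..., norm="ortho")`; `fnet_mixing` is a hypothetical name. With orthonormal scaling the complex output has the same energy as the input (Parseval's theorem). Keeping only the real part discards the imaginary energy, consistent with `energy_preserving: False` in `get_spectral_properties`.

```python
import numpy as np


def fnet_mixing(x, keep_complex=False):
    """2D FFT over the (sequence, feature) axes with orthonormal scaling."""
    x_freq = np.fft.fft2(x, axes=(-2, -1), norm="ortho")
    return x_freq if keep_complex else np.real(x_freq)


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 16, 8))  # (batch, sequence_length, hidden_dim)
mixed = fnet_mixing(x)
full = fnet_mixing(x, keep_complex=True)

energy_in = np.sum(x**2)
print(np.isclose(np.sum(np.abs(full) ** 2), energy_in))  # complex output keeps energy
print(np.sum(mixed**2) < energy_in)                      # real part alone loses some
```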

Parameters:

Name Type Description Default
hidden_dim int

Hidden dimension of the input tensors.

required
dropout float

Dropout probability applied after the mixing operation.

0.0
norm_eps float

Epsilon for numerical stability.

1e-5
energy_tolerance float

Tolerance for energy preservation verification.

1e-4
fft_norm str

Normalization mode for FFT operations ("forward", "backward", "ortho").

"ortho"
keep_complex bool

If True, keeps complex values from FFT. If False (default), takes only the real part as in original FNet.

False

Attributes:

Name Type Description
fft2d FFT2D

2D Fourier transform module.

keep_complex bool

Whether to keep complex values or extract real part.

Methods:

Name Description
forward

Apply Fourier mixing to input tensor.

get_spectral_properties

Get spectral properties of Fourier mixing.

from_config

Create FourierMixing layer from configuration.

Source code in spectrans/layers/mixing/fourier.py
def __init__(
    self,
    hidden_dim: int,
    dropout: float = 0.0,
    norm_eps: float = 1e-5,
    energy_tolerance: float = 1e-4,
    fft_norm: FFTNorm = "ortho",
    keep_complex: bool = False,
):
    super().__init__(hidden_dim, dropout, norm_eps, energy_tolerance)
    self.keep_complex = keep_complex
    # Store transform as non-module attribute to avoid PyTorch module registration
    self.fft2d: FFT2D  # Type annotation for mypy
    object.__setattr__(self, "fft2d", FFT2D(norm=fft_norm))
Functions
forward
forward(x: Tensor) -> Tensor

Apply Fourier mixing to input tensor.

Parameters:

Name Type Description Default
x Tensor

Input tensor of shape (batch_size, sequence_length, hidden_dim).

required

Returns:

Type Description
Tensor

Mixed tensor of same shape. Complex if keep_complex=True, real values only if keep_complex=False (default).

Source code in spectrans/layers/mixing/fourier.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Apply Fourier mixing to input tensor.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor of shape (batch_size, sequence_length, hidden_dim).

    Returns
    -------
    torch.Tensor
        Mixed tensor of same shape. Complex if keep_complex=True,
        real values only if keep_complex=False (default).
    """
    # Apply 2D FFT along last two dimensions (sequence and feature)
    x_freq = self.fft2d.transform(x, dim=(-2, -1))
    # Keep complex values for information preservation or take real part (default)
    x_mixed = x_freq if self.keep_complex else torch.real(x_freq)

    # Apply dropout
    x_mixed = self.dropout(x_mixed)

    return x_mixed  # type: ignore[no-any-return]
get_spectral_properties
get_spectral_properties() -> dict[str, str | bool]

Get spectral properties of Fourier mixing.

Returns:

Type Description
dict[str, str | bool]

Properties including energy preservation and domain information.

Source code in spectrans/layers/mixing/fourier.py
def get_spectral_properties(self) -> dict[str, str | bool]:
    """Get spectral properties of Fourier mixing.

    Returns
    -------
    dict[str, str | bool]
        Properties including energy preservation and domain information.
    """
    return {
        "unitary": False,
        "real_output": not self.keep_complex,
        "frequency_domain": True,
        "energy_preserving": False,
        "translation_equivariant": True,
        "learnable_parameters": False,
    }
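The "energy_preserving": False entry follows from Parseval's theorem. An orthonormal ("ortho") FFT preserves the squared L2 norm of its input, but discarding the imaginary part also discards the energy it carried. The quick numerical check below uses torch.fft and assumes FFT2D(norm="ortho") behaves like torch.fft.fft2 with norm="ortho".

```python
import torch

torch.manual_seed(0)
x = torch.randn(4, 32, 16)

x_freq = torch.fft.fft2(x, dim=(-2, -1), norm="ortho")

energy_in = x.pow(2).sum()
energy_complex = x_freq.abs().pow(2).sum()  # full complex spectrum
energy_real = x_freq.real.pow(2).sum()      # what FourierMixing returns by default

# Parseval: the orthonormal FFT alone keeps the energy (up to float error)
print(torch.allclose(energy_in, energy_complex, rtol=1e-4))
# Keeping only the real part loses energy
print(bool(energy_real < energy_complex))
```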
from_config classmethod
from_config(config: FourierMixingConfig) -> FourierMixing

Create FourierMixing layer from configuration.

Parameters:

Name Type Description Default
config FourierMixingConfig

Configuration object with layer parameters.

required

Returns:

Type Description
FourierMixing

Configured Fourier mixing layer.

Source code in spectrans/layers/mixing/fourier.py
@classmethod
def from_config(cls, config: "FourierMixingConfig") -> "FourierMixing":
    """Create FourierMixing layer from configuration.

    Parameters
    ----------
    config : FourierMixingConfig
        Configuration object with layer parameters.

    Returns
    -------
    FourierMixing
        Configured Fourier mixing layer.
    """
    return cls(
        hidden_dim=config.hidden_dim,
        dropout=config.dropout,
        norm_eps=config.norm_eps,
        energy_tolerance=config.energy_tolerance,
        fft_norm=config.fft_norm,
        keep_complex=config.keep_complex,
    )

FourierMixing1D

FourierMixing1D(hidden_dim: int, dropout: float = 0.0, norm_eps: float = 1e-05, energy_tolerance: float = 0.0001, fft_norm: FFTNorm = 'ortho', keep_complex: bool = False)

Bases: UnitaryMixingLayer

1D Fourier mixing along sequence dimension only.

Applies the Fourier transform only along the sequence dimension. Tokens are mixed, but each feature channel is transformed independently of the others.

Parameters:

Name Type Description Default
hidden_dim int

Hidden dimension of the input tensors.

required
dropout float

Dropout probability applied after the mixing operation.

0.0
norm_eps float

Epsilon for numerical stability.

1e-5
energy_tolerance float

Tolerance for energy preservation verification.

1e-4
fft_norm str

Normalization mode for FFT operations.

"ortho"
keep_complex bool

If True, keeps complex values from FFT. If False (default), takes only the real part.

False

Attributes:

Name Type Description
fft1d FFT1D

1D Fourier transform module.

keep_complex bool

Whether to keep complex values or extract real part.

Methods:

Name Description
forward

Apply 1D Fourier mixing to input tensor.

get_spectral_properties

Get spectral properties of 1D Fourier mixing.

Source code in spectrans/layers/mixing/fourier.py
def __init__(
    self,
    hidden_dim: int,
    dropout: float = 0.0,
    norm_eps: float = 1e-5,
    energy_tolerance: float = 1e-4,
    fft_norm: FFTNorm = "ortho",
    keep_complex: bool = False,
):
    super().__init__(hidden_dim, dropout, norm_eps, energy_tolerance)
    self.keep_complex = keep_complex
    # Store transform as non-module attribute to avoid PyTorch module registration
    self.fft1d: FFT1D  # Type annotation for mypy
    object.__setattr__(self, "fft1d", FFT1D(norm=fft_norm))
Functions
forward
forward(x: Tensor) -> Tensor

Apply 1D Fourier mixing to input tensor.

Parameters:

Name Type Description Default
x Tensor

Input tensor of shape (batch_size, sequence_length, hidden_dim).

required

Returns:

Type Description
Tensor

Mixed tensor with Fourier transform applied along sequence dimension. Complex if keep_complex=True, real values only if keep_complex=False.

Source code in spectrans/layers/mixing/fourier.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Apply 1D Fourier mixing to input tensor.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor of shape (batch_size, sequence_length, hidden_dim).

    Returns
    -------
    torch.Tensor
        Mixed tensor with Fourier transform applied along sequence dimension.
        Complex if keep_complex=True, real values only if keep_complex=False.
    """
    # Apply 1D FFT along sequence dimension only
    x_freq = self.fft1d.transform(x, dim=1)  # sequence dimension

    # Keep complex values or take real part (default behavior)
    x_mixed = x_freq if self.keep_complex else torch.real(x_freq)

    # Apply dropout
    x_mixed = self.dropout(x_mixed)

    return x_mixed  # type: ignore[no-any-return]
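FourierMixing1D transforms along dim=1 only, so each output channel depends only on the same input channel. Tokens are mixed, but features are not. The sketch below checks this with torch.fft.fft and assumes FFT1D(norm="ortho") matches torch.fft.fft(..., norm="ortho"). The name mix_1d is hypothetical.

```python
import torch


def mix_1d(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in for FourierMixing1D.forward (real output, no dropout)."""
    return torch.fft.fft(x, dim=1, norm="ortho").real


torch.manual_seed(0)
x = torch.randn(1, 10, 6)  # (batch_size, sequence_length, hidden_dim)

# Perturb one token in feature channel 0 only
x_perturbed = x.clone()
x_perturbed[0, 3, 0] += 1.0

diff = (mix_1d(x_perturbed) - mix_1d(x)).abs()

# Channel 0 changes at every sequence position: tokens are mixed ...
print(diff[0, :, 0])
# ... while the other feature channels are unaffected
print(bool((diff[0, :, 1:] < 1e-6).all()))
```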
get_spectral_properties
get_spectral_properties() -> dict[str, str | bool]

Get spectral properties of 1D Fourier mixing.

Returns:

Type Description
dict[str, str | bool]

Properties specific to 1D sequence mixing.

Source code in spectrans/layers/mixing/fourier.py
def get_spectral_properties(self) -> dict[str, str | bool]:
    """Get spectral properties of 1D Fourier mixing.

    Returns
    -------
    dict[str, str | bool]
        Properties specific to 1D sequence mixing.
    """
    return {
        "unitary": False,  # Real part extraction breaks unitarity
        "real_output": not self.keep_complex,
        "frequency_domain": True,
        "energy_preserving": False,
        "sequence_mixing_only": True,
        "feature_preserving": True,
        "learnable_parameters": False,
    }

GlobalFilterMixing

GlobalFilterMixing(hidden_dim: int, sequence_length: int, activation: ActivationType = 'sigmoid', dropout: float = 0.0, norm_eps: float = 1e-05, learnable_filters: bool = True, fft_norm: FFTNorm = 'ortho', filter_init_std: float = 0.02)

Bases: FilterMixingLayer

Global Filter Network mixing layer.

Implements the core GFNet mixing operation with learnable complex filters applied in the frequency domain along the sequence dimension.

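When the input sequence length differs from sequence_length, the learned filters are linearly interpolated along the frequency axis to the actual length. The layer therefore accepts variable-length inputs while keeping the overall shape of the learned frequency response. A fixed-size filter, by contrast, only fits inputs of exactly sequence_length.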

Parameters:

Name Type Description Default
hidden_dim int

Hidden dimension of input tensors.

required
sequence_length int

Base sequence length for filter parameter initialization. The filters will be interpolated to match actual input sequence lengths.

required
activation ActivationType

Activation function applied to filter parameters ("sigmoid", "tanh", "identity").

"sigmoid"
dropout float

Dropout probability applied after filtering.

0.0
norm_eps float

Epsilon for numerical stability.

1e-5
learnable_filters bool

Whether filter parameters are learnable. The value is passed to the base class. This class always creates filter_real and filter_imag as nn.Parameter objects.

True
fft_norm str

FFT normalization mode.

"ortho"
filter_init_std float

Standard deviation for filter parameter initialization.

0.02

Attributes:

Name Type Description
activation str

Activation function name.

filter_real Parameter

Real part of complex filter parameters.

filter_imag Parameter

Imaginary part of complex filter parameters.

fft1d FFT1D

1D FFT transform for sequence dimension.

activation_fn Module

Activation function module (Sigmoid, Tanh, or Identity).

Methods:

Name Description
forward

Apply global filtering to input tensor.

get_filter_response

Get the current frequency response of the filters.

get_spectral_properties

Get spectral properties of global filtering.

from_config

Create GlobalFilterMixing layer from configuration.

Source code in spectrans/layers/mixing/global_filter.py
def __init__(
    self,
    hidden_dim: int,
    sequence_length: int,
    activation: ActivationType = "sigmoid",
    dropout: float = 0.0,
    norm_eps: float = 1e-5,
    learnable_filters: bool = True,
    fft_norm: FFTNorm = "ortho",
    filter_init_std: float = 0.02,
):
    super().__init__(hidden_dim, sequence_length, dropout, norm_eps, learnable_filters)
    self.activation = activation

    # Initialize complex filter parameters
    self.filter_real = nn.Parameter(torch.randn(sequence_length, hidden_dim) * filter_init_std)
    self.filter_imag = nn.Parameter(torch.randn(sequence_length, hidden_dim) * filter_init_std)

    # Store FFT transform as non-module attribute
    self.fft1d: FFT1D  # Type annotation for mypy
    object.__setattr__(self, "fft1d", FFT1D(norm=fft_norm))

    # Activation function
    if activation == "sigmoid":
        self.activation_fn: Callable[[Tensor], Tensor] = nn.Sigmoid()
    elif activation == "tanh":
        self.activation_fn = nn.Tanh()
    elif activation == "identity":
        self.activation_fn = nn.Identity()
    else:
        raise ValueError(f"Unknown activation: {activation}")
Functions
forward
forward(x: Tensor) -> Tensor

Apply global filtering to input tensor.

Parameters:

Name Type Description Default
x Tensor

Input tensor of shape (batch_size, sequence_length, hidden_dim).

required

Returns:

Type Description
Tensor

Filtered tensor of same shape as input.

Source code in spectrans/layers/mixing/global_filter.py
def forward(self, x: Tensor) -> Tensor:
    """Apply global filtering to input tensor.

    Parameters
    ----------
    x : Tensor
        Input tensor of shape (batch_size, sequence_length, hidden_dim).

    Returns
    -------
    Tensor
        Filtered tensor of same shape as input.
    """
    # Transform to frequency domain
    x_freq = self.fft1d.transform(x, dim=1)  # Along sequence dimension

    # Get actual sequence length
    seq_len = x_freq.shape[1]

    # Adapt filter to actual sequence length using interpolation
    if seq_len != self.sequence_length:
        # Use interpolation to adapt filters to the actual sequence length
        # This preserves the learned frequency patterns at different resolutions
        filter_real = (
            nn.functional.interpolate(
                self.filter_real.T.unsqueeze(0),  # (1, hidden_dim, sequence_length)
                size=seq_len,
                mode="linear",
                align_corners=False,
            )
            .squeeze(0)
            .T
        )  # (seq_len, hidden_dim)

        filter_imag = (
            nn.functional.interpolate(
                self.filter_imag.T.unsqueeze(0),  # (1, hidden_dim, sequence_length)
                size=seq_len,
                mode="linear",
                align_corners=False,
            )
            .squeeze(0)
            .T
        )  # (seq_len, hidden_dim)
    else:
        filter_real = self.filter_real
        filter_imag = self.filter_imag

    # Create complex filter
    filter_complex = make_complex(
        self.activation_fn(filter_real), self.activation_fn(filter_imag)
    )

    # Apply filter in frequency domain (element-wise multiplication)
    filtered_freq = complex_multiply(x_freq, filter_complex)

    # Transform back to time domain
    filtered_time = self.fft1d.inverse_transform(filtered_freq, dim=1)

    # Take real part (assuming real-valued output is desired)
    output = torch.real(filtered_time)

    # Apply dropout
    output = self.dropout(output)

    return output  # type: ignore[no-any-return]
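The forward pass above has four steps: FFT along the sequence axis, interpolate the filters if the length differs, multiply by the complex filter, and inverse FFT. It can be sketched with plain torch.fft and torch.nn.functional.interpolate. The helper names resize_filter and global_filter_1d are hypothetical. The sketch skips the activation (equivalent to activation="identity") and dropout, and it uses torch.complex and * in place of the library's make_complex and complex_multiply. An all-pass filter (1 + 0j everywhere) should return the input unchanged, including for a length the filters were not initialized with.

```python
import torch
import torch.nn.functional as F


def resize_filter(f: torch.Tensor, seq_len: int) -> torch.Tensor:
    """Linearly interpolate a (base_seq, hidden) filter to (seq_len, hidden)."""
    # interpolate expects (batch, channels, length): treat hidden as channels
    return (
        F.interpolate(f.T.unsqueeze(0), size=seq_len, mode="linear", align_corners=False)
        .squeeze(0)
        .T
    )


def global_filter_1d(
    x: torch.Tensor, filter_real: torch.Tensor, filter_imag: torch.Tensor
) -> torch.Tensor:
    """Hypothetical sketch of GlobalFilterMixing.forward (identity activation, no dropout)."""
    seq_len = x.shape[1]
    if filter_real.shape[0] != seq_len:
        filter_real = resize_filter(filter_real, seq_len)
        filter_imag = resize_filter(filter_imag, seq_len)
    x_freq = torch.fft.fft(x, dim=1, norm="ortho")
    # Element-wise complex multiplication, broadcast over the batch axis
    filtered = x_freq * torch.complex(filter_real, filter_imag)
    return torch.fft.ifft(filtered, dim=1, norm="ortho").real


base_seq, hidden = 16, 8
ones = torch.ones(base_seq, hidden)
zeros = torch.zeros(base_seq, hidden)

torch.manual_seed(0)
x_same = torch.randn(2, 16, hidden)
x_longer = torch.randn(2, 24, hidden)  # filters get interpolated from 16 to 24

print(torch.allclose(global_filter_1d(x_same, ones, zeros), x_same, atol=1e-5))
print(torch.allclose(global_filter_1d(x_longer, ones, zeros), x_longer, atol=1e-5))
```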
get_filter_response
get_filter_response() -> Tensor

Get the current frequency response of the filters.

Returns:

Type Description
Tensor

Complex-valued frequency response of shape (sequence_length, hidden_dim).

Source code in spectrans/layers/mixing/global_filter.py
def get_filter_response(self) -> Tensor:
    """Get the current frequency response of the filters.

    Returns
    -------
    Tensor
        Complex-valued frequency response of shape (sequence_length, hidden_dim).
    """
    return make_complex(
        self.activation_fn(self.filter_real), self.activation_fn(self.filter_imag)
    )
get_spectral_properties
get_spectral_properties() -> dict[str, str | bool | int]

Get spectral properties of global filtering.

Returns:

Type Description
dict[str, str | bool | int]

Properties including filter characteristics.

Source code in spectrans/layers/mixing/global_filter.py
def get_spectral_properties(self) -> dict[str, str | bool | int]:
    """Get spectral properties of global filtering.

    Returns
    -------
    dict[str, str | bool | int]
        Properties including filter characteristics.
    """
    return {
        "frequency_domain": True,
        "learnable_filters": True,
        "complex_valued": True,
        "selective_filtering": True,
        "energy_preserving": False,  # Filtering can change energy
        "activation": self.activation,
        "parameter_count": 2 * self.sequence_length * self.hidden_dim,
    }
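The parameter_count entry equals 2 * sequence_length * hidden_dim. The layer stores two real tensors of shape (sequence_length, hidden_dim), one for the real part of the filter and one for the imaginary part. With the illustrative values sequence_length = 128 and hidden_dim = 512, this gives 2 × 128 × 512 = 131,072 parameters. The count does not depend on the actual input length, because other lengths reuse the same parameters through interpolation. The quick check below mirrors the shapes and initialization in __init__ above. It also builds the complex response the way get_filter_response does with the default sigmoid activation.

```python
import torch
import torch.nn as nn

sequence_length, hidden_dim = 128, 512  # illustrative sizes

# Same shapes and init scheme as GlobalFilterMixing.__init__ (filter_init_std = 0.02)
filter_real = nn.Parameter(torch.randn(sequence_length, hidden_dim) * 0.02)
filter_imag = nn.Parameter(torch.randn(sequence_length, hidden_dim) * 0.02)

n_params = filter_real.numel() + filter_imag.numel()
print(n_params)  # 131072

# Complex response with sigmoid activation: both parts lie in (0, 1)
response = torch.complex(torch.sigmoid(filter_real), torch.sigmoid(filter_imag))
print(response.shape, response.dtype)
```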
from_config classmethod
from_config(config: GlobalFilterMixingConfig) -> GlobalFilterMixing

Create GlobalFilterMixing layer from configuration.

Parameters:

Name Type Description Default
config GlobalFilterMixingConfig

Configuration object with layer parameters.

required

Returns:

Type Description
GlobalFilterMixing

Configured global filter mixing layer.

Source code in spectrans/layers/mixing/global_filter.py
@classmethod
def from_config(cls, config: "GlobalFilterMixingConfig") -> "GlobalFilterMixing":
    """Create GlobalFilterMixing layer from configuration.

    Parameters
    ----------
    config : GlobalFilterMixingConfig
        Configuration object with layer parameters.

    Returns
    -------
    GlobalFilterMixing
        Configured global filter mixing layer.
    """
    return cls(
        hidden_dim=config.hidden_dim,
        sequence_length=config.sequence_length,
        activation=config.activation,
        dropout=config.dropout,
        learnable_filters=config.learnable_filters,
        fft_norm=config.fft_norm,
        filter_init_std=config.filter_init_std,
    )

GlobalFilterMixing2D

GlobalFilterMixing2D(hidden_dim: int, sequence_length: int, activation: ActivationType = 'sigmoid', dropout: float = 0.0, norm_eps: float = 1e-05, learnable_filters: bool = True, fft_norm: FFTNorm = 'ortho', filter_init_std: float = 0.02)

Bases: FilterMixingLayer

2D Global Filter mixing with filtering along both dimensions.

Extends global filtering to both sequence and feature dimensions, similar to FNet's 2D FFT but with learnable filters.

Parameters:

Name Type Description Default
hidden_dim int

Hidden dimension of input tensors.

required
sequence_length int

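Base sequence length for filter initialization. The filters are bilinearly interpolated when the input's sequence length or hidden dimension differs.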

required
activation ActivationType

Activation function for filter parameters.

"sigmoid"
dropout float

Dropout probability.

0.0
norm_eps float

Epsilon for numerical stability.

1e-5
learnable_filters bool

Whether filters are learnable.

True
fft_norm str

FFT normalization mode.

"ortho"
filter_init_std float

Filter parameter initialization standard deviation.

0.02

Attributes:

Name Type Description
filter_real Parameter

Real part of 2D complex filters.

filter_imag Parameter

Imaginary part of 2D complex filters.

fft2d FFT2D

2D FFT transform module.

activation_fn Module

Activation function.

Methods:

Name Description
forward

Apply 2D global filtering.

get_filter_response

Get 2D frequency response.

get_spectral_properties

Get 2D filter properties.

Source code in spectrans/layers/mixing/global_filter.py
def __init__(
    self,
    hidden_dim: int,
    sequence_length: int,
    activation: ActivationType = "sigmoid",
    dropout: float = 0.0,
    norm_eps: float = 1e-5,
    learnable_filters: bool = True,
    fft_norm: FFTNorm = "ortho",
    filter_init_std: float = 0.02,
):
    super().__init__(hidden_dim, sequence_length, dropout, norm_eps, learnable_filters)
    self.activation = activation

    # Initialize 2D complex filter parameters
    self.filter_real = nn.Parameter(torch.randn(sequence_length, hidden_dim) * filter_init_std)
    self.filter_imag = nn.Parameter(torch.randn(sequence_length, hidden_dim) * filter_init_std)

    # Store 2D FFT transform as non-module attribute
    self.fft2d: FFT2D  # Type annotation for mypy
    object.__setattr__(self, "fft2d", FFT2D(norm=fft_norm))

    # Activation function
    if activation == "sigmoid":
        self.activation_fn: Callable[[Tensor], Tensor] = nn.Sigmoid()
    elif activation == "tanh":
        self.activation_fn = nn.Tanh()
    elif activation == "identity":
        self.activation_fn = nn.Identity()
    else:
        raise ValueError(f"Unknown activation: {activation}")
Functions
forward
forward(x: Tensor) -> Tensor

Apply 2D global filtering.

Parameters:

Name Type Description Default
x Tensor

Input tensor of shape (batch_size, sequence_length, hidden_dim).

required

Returns:

Type Description
Tensor

Filtered tensor of same shape.

Source code in spectrans/layers/mixing/global_filter.py
def forward(self, x: Tensor) -> Tensor:
    """Apply 2D global filtering.

    Parameters
    ----------
    x : Tensor
        Input tensor of shape (batch_size, sequence_length, hidden_dim).

    Returns
    -------
    Tensor
        Filtered tensor of same shape.
    """
    # Transform to 2D frequency domain
    x_freq = self.fft2d.transform(x, dim=(-2, -1))

    # Get actual dimensions
    seq_len = x_freq.shape[-2]
    hidden = x_freq.shape[-1]

    # Adapt filter to actual dimensions using bilinear interpolation
    if seq_len != self.sequence_length or hidden != self.hidden_dim:
        # Reshape for 2D interpolation
        filter_real = (
            nn.functional.interpolate(
                self.filter_real.unsqueeze(0).unsqueeze(0),  # (1, 1, seq_length, hidden_dim)
                size=(seq_len, hidden),
                mode="bilinear",
                align_corners=False,
            )
            .squeeze(0)
            .squeeze(0)
        )  # (seq_len, hidden)

        filter_imag = (
            nn.functional.interpolate(
                self.filter_imag.unsqueeze(0).unsqueeze(0),  # (1, 1, seq_length, hidden_dim)
                size=(seq_len, hidden),
                mode="bilinear",
                align_corners=False,
            )
            .squeeze(0)
            .squeeze(0)
        )  # (seq_len, hidden)
    else:
        filter_real = self.filter_real
        filter_imag = self.filter_imag

    # Create complex filter
    filter_complex = make_complex(
        self.activation_fn(filter_real), self.activation_fn(filter_imag)
    )

    # Apply 2D filter
    filtered_freq = complex_multiply(x_freq, filter_complex)

    # Transform back to spatial domain
    filtered_spatial = self.fft2d.inverse_transform(filtered_freq, dim=(-2, -1))

    # Take real part
    output = torch.real(filtered_spatial)

    # Apply dropout
    output = self.dropout(output)

    return output  # type: ignore[no-any-return]
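The 2D variant follows the same steps with torch.fft.fft2 and bilinear interpolation over both axes. In the hypothetical sketch below, activation and dropout are omitted. A filter that keeps only the DC coefficient turns every entry into the per-sample mean. This follows because, with "ortho" scaling, the DC bin is sum / sqrt(N), and the inverse divides by sqrt(N) again. An all-pass filter initialized at (8, 4) is resized to fit a (12, 6) input.

```python
import torch
import torch.nn.functional as F


def resize_filter_2d(f: torch.Tensor, size: tuple[int, int]) -> torch.Tensor:
    """Bilinearly resize a (seq, hidden) filter to `size`."""
    # (S, H) -> (1, 1, S, H) -> resize -> (s, h)
    return (
        F.interpolate(f.unsqueeze(0).unsqueeze(0), size=size, mode="bilinear", align_corners=False)
        .squeeze(0)
        .squeeze(0)
    )


def global_filter_2d(
    x: torch.Tensor, filter_real: torch.Tensor, filter_imag: torch.Tensor
) -> torch.Tensor:
    """Hypothetical sketch of GlobalFilterMixing2D.forward (identity activation, no dropout)."""
    size = (x.shape[-2], x.shape[-1])
    if tuple(filter_real.shape) != size:
        filter_real = resize_filter_2d(filter_real, size)
        filter_imag = resize_filter_2d(filter_imag, size)
    x_freq = torch.fft.fft2(x, dim=(-2, -1), norm="ortho")
    filtered = x_freq * torch.complex(filter_real, filter_imag)
    return torch.fft.ifft2(filtered, dim=(-2, -1), norm="ortho").real


torch.manual_seed(0)
x = torch.randn(2, 8, 4)

# Keep only the DC coefficient: every output entry becomes the per-sample mean
dc_only = torch.zeros(8, 4)
dc_only[0, 0] = 1.0
out = global_filter_2d(x, dc_only, torch.zeros(8, 4))
expected = x.mean(dim=(-2, -1), keepdim=True).expand_as(x)
print(torch.allclose(out, expected, atol=1e-5))

# An all-pass filter initialized at (8, 4) is resized for a (12, 6) input
x_big = torch.randn(2, 12, 6)
out_big = global_filter_2d(x_big, torch.ones(8, 4), torch.zeros(8, 4))
print(torch.allclose(out_big, x_big, atol=1e-5))
```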
get_filter_response
get_filter_response() -> Tensor

Get 2D frequency response.

Returns:

Type Description
Tensor

Complex 2D frequency response.

Source code in spectrans/layers/mixing/global_filter.py
def get_filter_response(self) -> Tensor:
    """Get 2D frequency response.

    Returns
    -------
    Tensor
        Complex 2D frequency response.
    """
    return make_complex(
        self.activation_fn(self.filter_real), self.activation_fn(self.filter_imag)
    )
get_spectral_properties
get_spectral_properties() -> dict[str, str | bool | int]

Get 2D filter properties.

Returns:

Type Description
dict[str, str | bool | int]

2D filtering characteristics.

Source code in spectrans/layers/mixing/global_filter.py
def get_spectral_properties(self) -> dict[str, str | bool | int]:
    """Get 2D filter properties.

    Returns
    -------
    dict[str, str | bool | int]
        2D filtering characteristics.
    """
    return {
        "frequency_domain": True,
        "learnable_filters": True,
        "complex_valued": True,
        "selective_filtering": True,
        "energy_preserving": False,
        "two_dimensional": True,
        "activation": self.activation,
        "parameter_count": 2 * self.sequence_length * self.hidden_dim,
    }

MixingLayer

MixingLayer(hidden_dim: int, dropout: float = 0.0, norm_eps: float = 1e-05)

Bases: SpectralComponent

Base class for spectral mixing operations.

Mixing layers perform token mixing operations using various spectral transforms instead of traditional attention mechanisms. This class provides spectral-specific functionality including mathematical property verification and standardized interfaces for spectral transform operations.

Parameters:

Name Type Description Default
hidden_dim int

Hidden dimension of the model.

required
dropout float

Dropout probability for regularization.

0.0
norm_eps float

Epsilon for numerical stability in normalization.

1e-5

Attributes:

Name Type Description
hidden_dim int

Hidden dimension of the model.

dropout Module

Dropout layer for regularization.

norm_eps float

Epsilon for numerical stability.

Methods:

Name Description
get_spectral_properties

Get mathematical properties of the spectral operation.

verify_shape_consistency

Verify that input and output shapes are consistent.

compute_spectral_norm

Compute spectral norm for analysis and regularization.

Source code in spectrans/layers/mixing/base.py
def __init__(
    self,
    hidden_dim: int,
    dropout: float = 0.0,
    norm_eps: float = 1e-5,
):
    super().__init__()
    self.hidden_dim = hidden_dim
    self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
    self.norm_eps = norm_eps
Functions
get_spectral_properties abstractmethod
get_spectral_properties() -> dict[str, Any]

Get mathematical properties of the spectral operation.

Returns:

Type Description
dict[str, Any]

Dictionary containing mathematical properties such as:

- 'unitary': bool, whether the transform is unitary
- 'real_output': bool, whether output is guaranteed real
- 'frequency_domain': bool, whether the operation occurs in the frequency domain
- 'energy_preserving': bool, whether energy is preserved

Source code in spectrans/layers/mixing/base.py
@abstractmethod
def get_spectral_properties(self) -> dict[str, Any]:
    """Get mathematical properties of the spectral operation.

    Returns
    -------
    dict[str, Any]
        Dictionary containing mathematical properties such as:
        - 'unitary': bool, whether the transform is unitary
        - 'real_output': bool, whether output is guaranteed real
        - 'frequency_domain': bool, whether operation occurs in frequency domain
        - 'energy_preserving': bool, whether energy is preserved
    """
    pass
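The sketch below shows the contract a concrete subclass must satisfy: implement get_spectral_properties, otherwise instantiation fails. MixingLayerSketch is a hypothetical stand-in that copies the MixingLayer constructor shown above. The real class derives from SpectralComponent, which is not reproduced here. HartleyMixing is likewise a hypothetical example, not a spectrans layer. It applies an orthonormal discrete Hartley transform along the sequence axis. That transform is real, orthogonal and its own inverse, so every listed property is True.

```python
from abc import ABC, abstractmethod
from typing import Any

import torch
import torch.nn as nn


class MixingLayerSketch(nn.Module, ABC):
    """Hypothetical stand-in for MixingLayer (constructor as documented above)."""

    def __init__(self, hidden_dim: int, dropout: float = 0.0, norm_eps: float = 1e-5):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
        self.norm_eps = norm_eps

    @abstractmethod
    def get_spectral_properties(self) -> dict[str, Any]:
        """Subclasses must describe their mathematical properties."""


class HartleyMixing(MixingLayerSketch):
    """Hypothetical layer: orthonormal discrete Hartley transform along the sequence."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_freq = torch.fft.fft(x, dim=1, norm="ortho")
        # For real x, DHT = Re(FFT) - Im(FFT) = sum_n x_n * (cos + sin)
        return self.dropout(x_freq.real - x_freq.imag)

    def get_spectral_properties(self) -> dict[str, Any]:
        return {
            "unitary": True,
            "real_output": True,
            "frequency_domain": True,
            "energy_preserving": True,
        }


torch.manual_seed(0)
layer = HartleyMixing(hidden_dim=8)
x = torch.randn(2, 16, 8)
y = layer(x)

print(layer.get_spectral_properties())
print(torch.allclose(layer(y), x, atol=1e-5))                     # DHT is its own inverse
print(torch.allclose(y.pow(2).sum(), x.pow(2).sum(), rtol=1e-4))  # energy is kept
```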
verify_shape_consistency
verify_shape_consistency(input_tensor: Tensor, output_tensor: Tensor) -> bool

Verify that input and output shapes are consistent.

Parameters:

Name Type Description Default
input_tensor Tensor

Input tensor to the mixing layer.

required
output_tensor Tensor

Output tensor from the mixing layer.

required

Returns:

Type Description
bool

True if shapes are consistent, False otherwise.

Source code in spectrans/layers/mixing/base.py
def verify_shape_consistency(
    self, input_tensor: torch.Tensor, output_tensor: torch.Tensor
) -> bool:
    """Verify that input and output shapes are consistent.

    Parameters
    ----------
    input_tensor : torch.Tensor
        Input tensor to the mixing layer.
    output_tensor : torch.Tensor
        Output tensor from the mixing layer.

    Returns
    -------
    bool
        True if shapes are consistent, False otherwise.
    """
    if input_tensor.shape != output_tensor.shape:
        return False

    # Verify batch dimension consistency
    if input_tensor.size(0) != output_tensor.size(0):
        return False

    # Verify sequence length consistency
    if input_tensor.size(1) != output_tensor.size(1):
        return False

    # Verify hidden dimension consistency
    return input_tensor.size(2) == output_tensor.size(2)
compute_spectral_norm
compute_spectral_norm(tensor: Tensor) -> Tensor

Compute spectral norm for analysis and regularization.

Parameters:

Name Type Description Default
tensor Tensor

Input tensor to compute spectral norm for.

required

Returns:

Type Description
Tensor

Largest singular value of the input reshaped to a (batch_size * sequence_length, hidden_dim) matrix.

Source code in spectrans/layers/mixing/base.py
def compute_spectral_norm(self, tensor: torch.Tensor) -> torch.Tensor:
    """Compute spectral norm for analysis and regularization.

    Parameters
    ----------
    tensor : torch.Tensor
        Input tensor of shape (batch_size, sequence_length, hidden_dim).

    Returns
    -------
    torch.Tensor
        Largest singular value of the input reshaped to a
        (batch_size * sequence_length, hidden_dim) matrix.
    """
    # Flatten batch and sequence axes into one matrix; reshape (unlike view)
    # also handles non-contiguous inputs
    batch_size, seq_len, hidden_dim = tensor.shape
    matrix = tensor.reshape(batch_size * seq_len, hidden_dim)

    # Singular values only (torch.svd is deprecated and also computes U and V)
    s = torch.linalg.svdvals(matrix)

    # Spectral norm is the maximum singular value
    return s.max()
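Two points about this computation are worth checking with a tiny example. First, the norm is taken over the flattened (batch_size * sequence_length, hidden_dim) matrix. For a flattened diag(3, 4) it is 4. Second, Tensor.view fails on non-contiguous inputs such as a transposed batch, while reshape copies when needed. The sketch uses only documented torch.linalg functions.

```python
import torch

# Batch of 1, two tokens, hidden_dim 2: the flattened matrix is diag(3, 4)
tensor = torch.tensor([[[3.0, 0.0], [0.0, 4.0]]])  # shape (1, 2, 2)
batch_size, seq_len, hidden_dim = tensor.shape
matrix = tensor.reshape(batch_size * seq_len, hidden_dim)

# Largest singular value, computed two equivalent ways (both approximately 4.0)
via_svdvals = torch.linalg.svdvals(matrix).max()
via_matrix_norm = torch.linalg.matrix_norm(matrix, ord=2)
print(via_svdvals.item(), via_matrix_norm.item())

# view() cannot merge axes of a non-contiguous tensor; reshape() can
nc = torch.randn(2, 6, 5).transpose(0, 1)  # shape (6, 2, 5), non-contiguous
try:
    nc.view(12, 5)
    view_failed = False
except RuntimeError:
    view_failed = True
print(view_failed, nc.reshape(12, 5).shape)
```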

RealFourierMixing

RealFourierMixing(hidden_dim: int, use_real_fft: bool = True, dropout: float = 0.0, norm_eps: float = 1e-05, energy_tolerance: float = 0.0001, fft_norm: FFTNorm = 'ortho')

Bases: UnitaryMixingLayer

Memory-efficient real Fourier mixing.

Uses real FFT operations to exploit Hermitian symmetry, providing ~2x memory and computational savings for real inputs.

Parameters:

Name Type Description Default
hidden_dim int

Hidden dimension of the input tensors.

required
use_real_fft bool

Whether to use real FFT for efficiency.

True
dropout float

Dropout probability applied after mixing.

0.0
norm_eps float

Epsilon for numerical stability.

1e-5
energy_tolerance float

Tolerance for energy preservation verification.

1e-4
fft_norm str

Normalization mode for FFT operations.

"ortho"

Attributes:

Name Type Description
use_real_fft bool

Whether real FFT is enabled.

rfft RFFT

Real FFT transform for sequence dimension.

rfft2d RFFT2D

Real 2D FFT transform for both dimensions.

Methods:

Name Description
forward

Apply real Fourier mixing to input tensor.

get_spectral_properties

Get spectral properties of real Fourier mixing.

Source code in spectrans/layers/mixing/fourier.py
def __init__(
    self,
    hidden_dim: int,
    use_real_fft: bool = True,
    dropout: float = 0.0,
    norm_eps: float = 1e-5,
    energy_tolerance: float = 1e-4,
    fft_norm: FFTNorm = "ortho",
):
    super().__init__(hidden_dim, dropout, norm_eps, energy_tolerance)
    self.use_real_fft = use_real_fft

    if use_real_fft:
        # Type annotations for mypy
        self.rfft: RFFT
        self.rfft2d: RFFT2D
        # Store transforms as non-module attributes
        object.__setattr__(self, "rfft", RFFT(norm=fft_norm))
        object.__setattr__(self, "rfft2d", RFFT2D(norm=fft_norm))

    # Complex FFT fallback. Always created because forward() also uses it for
    # non-floating-point inputs (e.g. complex tensors) when use_real_fft=True.
    self.fft2d: FFT2D  # Type annotation for mypy
    object.__setattr__(self, "fft2d", FFT2D(norm=fft_norm))
Functions
forward
forward(x: Tensor) -> Tensor

Apply real Fourier mixing to input tensor.

Parameters:

Name Type Description Default
x Tensor

Input tensor of shape (batch_size, sequence_length, hidden_dim). Should be real-valued for optimal efficiency.

required

Returns:

Type Description
Tensor

Mixed tensor, guaranteed to be real-valued.

Source code in spectrans/layers/mixing/fourier.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Apply real Fourier mixing to input tensor.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor of shape (batch_size, sequence_length, hidden_dim).
        Should be real-valued for optimal efficiency.

    Returns
    -------
    torch.Tensor
        Mixed tensor, guaranteed to be real-valued.
    """
    if self.use_real_fft and torch.is_floating_point(x):
        # Use real FFT for efficiency
        x_freq = self.rfft2d.transform(x, dim=(-2, -1))
        # Inverse RFFT automatically returns real values
        x_mixed = self.rfft2d.inverse_transform(x_freq, dim=(-2, -1))
    else:
        # Fallback to complex FFT with real part extraction
        x_freq = self.fft2d.transform(x, dim=(-2, -1))
        x_mixed = torch.real(x_freq)

    # Apply dropout
    x_mixed = self.dropout(x_mixed)

    return x_mixed  # type: ignore[no-any-return]
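The two branches of forward compute different things. This is worth checking under an assumption: that RFFT2D.transform and RFFT2D.inverse_transform wrap torch.fft.rfft2 and torch.fft.irfft2. The listing only says the inverse returns real values. Under that assumption, the use_real_fft=True branch applies a transform and its exact inverse, which returns the input (up to float error). The fallback branch returns Re(FFT2(x)). The sketch below uses torch.fft directly to show three things. First, rfft2 keeps only n // 2 + 1 bins on the last axis, which is where the memory savings come from. Second, irfft2 needs the original sizes through s= when hidden_dim is odd. Third, the round trip is the identity.

```python
import torch

torch.manual_seed(0)
x = torch.randn(2, 16, 7)  # odd hidden_dim on purpose

# rfft2 keeps only the non-redundant half of the last axis: 7 // 2 + 1 = 4 bins
x_half = torch.fft.rfft2(x, dim=(-2, -1), norm="ortho")
print(x_half.shape)  # torch.Size([2, 16, 4])

# Without s=, irfft2 assumes an even last size 2 * (4 - 1) = 6
print(torch.fft.irfft2(x_half, dim=(-2, -1), norm="ortho").shape)  # torch.Size([2, 16, 6])

# With the original sizes, the round trip recovers x: it is the identity map
x_back = torch.fft.irfft2(x_half, s=x.shape[-2:], dim=(-2, -1), norm="ortho")
print(torch.allclose(x_back, x, atol=1e-5))

# The fallback branch instead returns the real part of the full 2D spectrum
x_fnet = torch.fft.fft2(x, dim=(-2, -1), norm="ortho").real
print(torch.allclose(x_fnet, x, atol=1e-5))
```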
get_spectral_properties
get_spectral_properties() -> dict[str, str | bool]

Get spectral properties of real Fourier mixing.

Returns:

Type Description
dict[str, str | bool]

Properties including efficiency characteristics.

Source code in spectrans/layers/mixing/fourier.py
def get_spectral_properties(self) -> dict[str, str | bool]:
    """Get spectral properties of real Fourier mixing.

    Returns
    -------
    dict[str, str | bool]
        Properties including efficiency characteristics.
    """
    return {
        "unitary": self.use_real_fft,  # Real FFT preserves unitarity
        "real_output": True,
        "frequency_domain": True,
        "energy_preserving": self.use_real_fft,
        "memory_efficient": self.use_real_fft,
        "hermitian_symmetry": self.use_real_fft,
        "learnable_parameters": False,
    }

SeparableFourierMixing

SeparableFourierMixing(hidden_dim: int, mix_features: bool = True, mix_sequence: bool = True, dropout: float = 0.0, norm_eps: float = 1e-05, energy_tolerance: float = 0.0001, fft_norm: FFTNorm = 'ortho')

Bases: UnitaryMixingLayer

Separable Fourier mixing with sequence and feature transforms.

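Applies separate 1D Fourier transforms along the sequence and feature dimensions, each of which can be enabled independently. This can be more efficient than a 2D FFT for certain tensor shapes. Because the real part is taken after each 1D transform, enabling both axes generally gives a different result from the real part of a single 2D FFT (as used by FourierMixing).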

Parameters:

Name Type Description Default
hidden_dim int

Hidden dimension of the input tensors.

required
mix_features bool

Whether to apply FFT along feature dimension.

True
mix_sequence bool

Whether to apply FFT along sequence dimension.

True
dropout float

Dropout probability.

0.0
norm_eps float

Epsilon for numerical stability.

1e-5
energy_tolerance float

Tolerance for energy preservation verification.

1e-4
fft_norm str

FFT normalization mode.

"ortho"

Attributes:

Name Type Description
mix_features bool

Whether feature mixing is enabled.

mix_sequence bool

Whether sequence mixing is enabled.

fft1d FFT1D

1D FFT transform module.

Methods:

Name Description
forward

Apply separable Fourier mixing.

get_spectral_properties

Get properties of separable mixing.

Source code in spectrans/layers/mixing/fourier.py
def __init__(
    self,
    hidden_dim: int,
    mix_features: bool = True,
    mix_sequence: bool = True,
    dropout: float = 0.0,
    norm_eps: float = 1e-5,
    energy_tolerance: float = 1e-4,
    fft_norm: FFTNorm = "ortho",
):
    super().__init__(hidden_dim, dropout, norm_eps, energy_tolerance)
    self.mix_features = mix_features
    self.mix_sequence = mix_sequence
    # Store transform as non-module attribute
    self.fft1d: FFT1D  # Type annotation for mypy
    object.__setattr__(self, "fft1d", FFT1D(norm=fft_norm))

    if not mix_features and not mix_sequence:
        raise ValueError("At least one of mix_features or mix_sequence must be True")
Functions
forward
forward(x: Tensor) -> Tensor

Apply separable Fourier mixing.

Parameters:

Name Type Description Default
x Tensor

Input tensor of shape (batch_size, sequence_length, hidden_dim).

required

Returns:

Type Description
Tensor

Mixed tensor after applying selected transforms.

Source code in spectrans/layers/mixing/fourier.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Apply separable Fourier mixing.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor of shape (batch_size, sequence_length, hidden_dim).

    Returns
    -------
    torch.Tensor
        Mixed tensor after applying selected transforms.
    """
    # Apply sequence mixing (along dim=1)
    if self.mix_sequence:
        x_freq_seq = self.fft1d.transform(x, dim=1)
        x = torch.real(x_freq_seq)

    # Apply feature mixing (along dim=2)
    if self.mix_features:
        x_freq_feat = self.fft1d.transform(x, dim=2)
        x = torch.real(x_freq_feat)

    # Apply dropout
    x = self.dropout(x)

    return x
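Taking the real part between the two passes changes the result. Write A = FFT_seq(x) and let M be the (symmetric) DFT matrix of the feature axis. Then Re(A M) = Re(A) Re(M) - Im(A) Im(M), while the separable path keeps only Re(A) Re(M). The sketch below verifies this with torch.fft, assuming FFT1D(norm="ortho") matches torch.fft.fft(..., norm="ortho"). It also confirms that without the intermediate real part, two 1D passes equal one 2D FFT.

```python
import torch

torch.manual_seed(0)
x = torch.randn(2, 8, 6)  # (batch_size, sequence_length, hidden_dim)

seq_fft = torch.fft.fft(x, dim=1, norm="ortho")  # A = FFT along the sequence
full_2d = torch.fft.fft2(x, dim=(-2, -1), norm="ortho")

# Without real-part extraction, two 1D passes equal one 2D FFT
print(torch.allclose(torch.fft.fft(seq_fft, dim=2, norm="ortho"), full_2d, atol=1e-5))

# SeparableFourierMixing with both axes enabled: real part after each pass
separable = torch.fft.fft(seq_fft.real, dim=2, norm="ortho").real
print(torch.allclose(separable, full_2d.real, atol=1e-5))  # not equal

# The gap is exactly the cross term -Im(A) @ Im(M)
dft = torch.fft.fft(torch.eye(6), norm="ortho")  # 6x6 orthonormal DFT matrix M
cross_term = -(seq_fft.imag @ dft.imag)
print(torch.allclose(full_2d.real - separable, cross_term, atol=1e-5))
```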
get_spectral_properties
get_spectral_properties() -> dict[str, str | bool]

Get properties of separable mixing.

Returns:

Type Description
dict[str, str | bool]

Properties reflecting the separable nature.

Source code in spectrans/layers/mixing/fourier.py
def get_spectral_properties(self) -> dict[str, str | bool]:
    """Get properties of separable mixing.

    Returns
    -------
    dict[str, str | bool]
        Properties reflecting the separable nature.
    """
    return {
        "unitary": False,  # Real part extraction
        "real_output": True,
        "frequency_domain": True,
        "energy_preserving": False,
        "separable": True,
        "sequence_mixing": self.mix_sequence,
        "feature_mixing": self.mix_features,
        "learnable_parameters": False,
    }
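The "unitary": False and "energy_preserving": False entries follow from the real-part extraction. The check below uses only torch.fft, not a spectrans API. The orthonormal FFT keeps the total squared norm of the sequence (Parseval's theorem), while keeping only the real part, as forward does, loses the energy held in the imaginary part.

```python
import torch

torch.manual_seed(0)
x = torch.randn(4, 32, 8, dtype=torch.float64)  # (batch_size, sequence_length, hidden_dim)

x_freq = torch.fft.fft(x, dim=1, norm="ortho")  # orthonormal FFT along the sequence axis
energy_in = x.pow(2).sum()
energy_complex = x_freq.abs().pow(2).sum()  # Parseval: equals energy_in
energy_real = x_freq.real.pow(2).sum()  # only the real part, as in forward

print(torch.allclose(energy_in, energy_complex))  # True
print(bool(energy_real < energy_in))  # True
```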

UnitaryMixingLayer

UnitaryMixingLayer(hidden_dim: int, dropout: float = 0.0, norm_eps: float = 1e-05, energy_tolerance: float = 0.0001)

Bases: MixingLayer

Base class for unitary mixing operations.

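Unitary mixing layers preserve the energy (squared L2 norm) and inner products of their inputs, so a layer neither amplifies nor attenuates activations and its transform can be inverted exactly by the conjugate transpose. These properties help keep training stable and underpin the theoretical guarantees of spectral transformers.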

Parameters:

Name Type Description Default
hidden_dim int

Hidden dimension of the model.

required
dropout float

Dropout probability for regularization.

0.0
norm_eps float

Epsilon for numerical stability.

1e-5
energy_tolerance float

Tolerance for energy preservation verification.

1e-4

Attributes:

Name Type Description
energy_tolerance float

Tolerance for energy preservation checks.

Methods:

Name Description
get_spectral_properties

Get properties specific to unitary transforms.

verify_energy_preservation

Verify energy preservation (Parseval's theorem).

verify_orthogonality

Verify orthogonality of the transform matrix.

Source code in spectrans/layers/mixing/base.py
def __init__(
    self,
    hidden_dim: int,
    dropout: float = 0.0,
    norm_eps: float = 1e-5,
    energy_tolerance: float = 1e-4,
):
    super().__init__(hidden_dim, dropout, norm_eps)
    self.energy_tolerance = energy_tolerance
Functions
get_spectral_properties
get_spectral_properties() -> dict[str, Any]

Get properties specific to unitary transforms.

Returns:

Type Description
dict[str, Any]

Dictionary containing unitary transform properties.

Source code in spectrans/layers/mixing/base.py
def get_spectral_properties(self) -> dict[str, Any]:
    """Get properties specific to unitary transforms.

    Returns
    -------
    dict[str, Any]
        Dictionary containing unitary transform properties.
    """
    return {
        "unitary": True,
        "energy_preserving": True,
        "invertible": True,
        "orthogonal": True,
        "spectrum_preserving": True,
    }
verify_energy_preservation
verify_energy_preservation(input_tensor: Tensor, output_tensor: Tensor) -> bool

Verify energy preservation (Parseval's theorem).

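Computes the squared L2 norm of every vector along the last dimension and checks that, for each pair of corresponding input and output vectors \(\mathbf{x}\) and \(\mathbf{y}\), the relative error \(\big|\|\mathbf{x}\|^2 - \|\mathbf{y}\|^2\big| / (\max(\|\mathbf{x}\|^2, \|\mathbf{y}\|^2) + \epsilon)\) is below energy_tolerance, where \(\epsilon\) is norm_eps.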

Parameters:

Name Type Description Default
input_tensor Tensor

Input tensor before transformation.

required
output_tensor Tensor

Output tensor after transformation.

required

Returns:

Type Description
bool

True if energy is preserved within tolerance.

Source code in spectrans/layers/mixing/base.py
def verify_energy_preservation(
    self, input_tensor: torch.Tensor, output_tensor: torch.Tensor
) -> bool:
    r"""Verify energy preservation (Parseval's theorem).

    Checks that $||\mathbf{output}||^2 \approx ||\mathbf{input}||^2$ within tolerance.

    Parameters
    ----------
    input_tensor : torch.Tensor
        Input tensor before transformation.
    output_tensor : torch.Tensor
        Output tensor after transformation.

    Returns
    -------
    bool
        True if energy is preserved within tolerance.
    """
    input_energy = torch.norm(input_tensor, p=2, dim=-1) ** 2
    output_energy = torch.norm(output_tensor, p=2, dim=-1) ** 2

    energy_diff = torch.abs(input_energy - output_energy)
    max_energy = torch.max(input_energy, output_energy)

    # Relative error should be within tolerance
    relative_error = energy_diff / (max_energy + self.norm_eps)

    return bool(torch.all(relative_error < self.energy_tolerance))
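A standalone version of the same relative-error test makes its per-vector behavior concrete. energy_preserved is a hypothetical helper that mirrors the method's arithmetic; it uses torch.linalg.vector_norm so that complex outputs are handled. It compares an orthonormal FFT, an unnormalized FFT, and an orthonormal FFT along the sequence axis.

```python
import torch


def energy_preserved(
    inp: torch.Tensor, out: torch.Tensor, tol: float = 1e-4, eps: float = 1e-5
) -> bool:
    """Hypothetical standalone version of verify_energy_preservation's arithmetic."""
    # Squared L2 norm of every vector along the last dimension (complex-safe)
    e_in = torch.linalg.vector_norm(inp, ord=2, dim=-1) ** 2
    e_out = torch.linalg.vector_norm(out, ord=2, dim=-1) ** 2
    rel_err = (e_in - e_out).abs() / (torch.maximum(e_in, e_out) + eps)
    return bool(torch.all(rel_err < tol))


torch.manual_seed(0)
x = torch.randn(2, 16, 32, dtype=torch.float64)
print(energy_preserved(x, torch.fft.fft(x, dim=-1, norm="ortho")))  # True: per-vector Parseval
print(energy_preserved(x, torch.fft.fft(x, dim=-1)))  # False: unnormalized FFT scales energy by 32
print(energy_preserved(x, torch.fft.fft(x, dim=1, norm="ortho")))  # False: mixes across tokens
```

The last case shows that the check works per token. An orthonormal transform along the sequence axis conserves each channel's energy summed over the whole sequence, but not the energy of individual vectors along the last dimension.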
verify_orthogonality
verify_orthogonality(transform_matrix: Tensor) -> bool

Verify orthogonality of the transform matrix.

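Checks that \(\mathbf{T} \mathbf{T}^H \approx \mathbf{I}\) by requiring the Frobenius norm of \(\mathbf{T} \mathbf{T}^H - \mathbf{I}\) to be below energy_tolerance. For complex matrices this is the unitarity condition; for real matrices it reduces to orthogonality.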

Parameters:

Name Type Description Default
transform_matrix Tensor

Transform matrix to verify.

required

Returns:

Type Description
bool

True if matrix is orthogonal within tolerance.

Source code in spectrans/layers/mixing/base.py
def verify_orthogonality(self, transform_matrix: torch.Tensor) -> bool:
    r"""Verify orthogonality of the transform matrix.

    Checks that $\mathbf{T} \mathbf{T}^H \approx \mathbf{I}$ (identity matrix).

    Parameters
    ----------
    transform_matrix : torch.Tensor
        Transform matrix to verify.

    Returns
    -------
    bool
        True if matrix is orthogonal within tolerance.
    """
    # Compute T @ T^H
    product = torch.matmul(transform_matrix, transform_matrix.conj().transpose(-2, -1))

    # Expected identity matrix, sized by the rows of T (T @ T^H is rows x rows)
    identity = torch.eye(
        transform_matrix.size(-2), device=transform_matrix.device, dtype=transform_matrix.dtype
    )

    # Check deviation from identity
    deviation = torch.norm(product - identity, p="fro")

    return bool(deviation < self.energy_tolerance)
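To see the condition on a concrete transform, the sketch below builds the orthonormal DFT matrix \(T_{km} = e^{-2\pi i k m / n} / \sqrt{n}\) and applies the same Frobenius-norm test. dft_matrix and is_unitary are hypothetical helpers, not spectrans functions; is_unitary sizes the identity by the row count so that the comparison has the right shape.

```python
import math

import torch


def dft_matrix(n: int) -> torch.Tensor:
    """Orthonormal DFT matrix with entries exp(-2*pi*i*k*m/n) / sqrt(n)."""
    k = torch.arange(n, dtype=torch.float64)
    angles = -2 * math.pi * torch.outer(k, k) / n
    return torch.polar(torch.ones_like(angles), angles) / math.sqrt(n)


def is_unitary(t: torch.Tensor, tol: float = 1e-4) -> bool:
    """Frobenius-norm test of T T^H against the identity, as in verify_orthogonality."""
    product = t @ t.conj().transpose(-2, -1)
    identity = torch.eye(t.size(-2), dtype=t.dtype, device=t.device)  # T T^H is rows x rows
    return bool(torch.linalg.matrix_norm(product - identity, ord="fro") < tol)


T = dft_matrix(16)
print(is_unitary(T), is_unitary(2 * T))  # True False

# Multiplying by T matches torch.fft.fft with norm="ortho"
x = torch.randn(16, dtype=torch.float64)
print(torch.allclose(T @ x.to(torch.complex128), torch.fft.fft(x, norm="ortho")))  # True
```

Since the tolerance bounds the Frobenius norm of the whole deviation matrix rather than each entry, it becomes stricter per entry as the matrix grows.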

WaveletMixing

WaveletMixing(hidden_dim: int, wavelet: WaveletType = 'db4', levels: int = 3, mixing_mode: str = 'pointwise', dropout: float = 0.0)

Bases: Module

Token mixing layer using discrete wavelet transform.

Performs token mixing in the wavelet domain for multi-resolution processing: decomposes the input with a DWT, applies learnable mixing to the coefficients, reconstructs the signal with the inverse DWT, and adds a residual connection.

Mathematical Formulation

Given input tensor \(\mathbf{X} \in \mathbb{R}^{B \times N \times D}\) where \(B\) is batch size, \(N\) is sequence length, and \(D\) is hidden dimension:

Step 1: Channel-wise Decomposition

For each channel \(d \in \{0, 1, \ldots, D-1\}\), extract the channel signal:

\[ \mathbf{x}^{(d)} = \mathbf{X}[:, :, d] \in \mathbb{R}^{B \times N} \]

Apply \(J\)-level DWT decomposition:

\[ \text{DWT}_J(\mathbf{x}^{(d)}) = \{\mathbf{c}_{A_J}^{(d)}, \{\mathbf{c}_{D_j}^{(d)}\}_{j=1}^J\} \]

Where:
  • \(\mathbf{c}_{A_J}^{(d)} \in \mathbb{R}^{B \times L_{A_J}}\) are the approximation coefficients at level \(J\)
  • \(\mathbf{c}_{D_j}^{(d)} \in \mathbb{R}^{B \times L_{D_j}}\) are the detail coefficients at level \(j\)
  • \(L_{A_J}\) and \(L_{D_j}\) are the coefficient lengths after subsampling
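For intuition about the coefficient lengths, the sketch below implements a \(J\)-level orthonormal Haar DWT in plain PyTorch. Haar is an assumption chosen for clarity, not the library's DWT1D: the layer defaults to db4 with symmetric boundary extension, where each level is a few samples longer than \(N/2^j\). The order in which DWT1D.decompose returns detail levels is not specified here; in this sketch details[0] is the finest level.

```python
import math

import torch


def haar_dwt(x: torch.Tensor, levels: int) -> tuple[torch.Tensor, list[torch.Tensor]]:
    """J-level orthonormal Haar DWT along the last axis.

    Assumes the length is divisible by 2**levels. Returns (approx_J, [detail_1, ..., detail_J]).
    """
    details, approx = [], x
    for _ in range(levels):
        even, odd = approx[..., 0::2], approx[..., 1::2]
        details.append((even - odd) / math.sqrt(2))  # high-pass branch, subsampled by 2
        approx = (even + odd) / math.sqrt(2)  # low-pass branch, subsampled by 2
    return approx, details


x = torch.randn(4, 64)  # x^(d): one channel, shape (B, N)
approx, details = haar_dwt(x, levels=3)
print(approx.shape[-1], [c.shape[-1] for c in details])  # 8 [32, 16, 8]
```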

Step 2: Learnable Mixing

Apply mixing transformations based on mode:

Pointwise Mixing (mixing_mode='pointwise'):

\[ \tilde{\mathbf{c}}_{A_J}^{(d)} = \mathbf{c}_{A_J}^{(d)} \cdot \mathbf{W}_{A}[0, 0, d] \]
\[ \tilde{\mathbf{c}}_{D_j}^{(d)} = \mathbf{c}_{D_j}^{(d)} \cdot \mathbf{W}_{D_j}[0, 0, d] \]

Where \(\mathbf{W}_{A}, \mathbf{W}_{D_j} \in \mathbb{R}^{1 \times 1 \times D}\) are learnable parameters initialized to ones. The code slices them as \(\mathbf{W}[:, :L, d]\), but because the second dimension has size 1 the slice is a single scalar, so each level of each channel is scaled by one learnable weight that broadcasts over the batch and all coefficient positions.

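Channel Mixing (mixing_mode='channel'):

\[ \tilde{\mathbf{c}}_{A_J}^{(d)} = \mathbf{c}_{A_J}^{(d)} \cdot \mathbf{W}_{A}[0, d, d] \]
\[ \tilde{\mathbf{c}}_{D_j}^{(d)} = \mathbf{c}_{D_j}^{(d)} \cdot \mathbf{W}_{D_j}[0, d, d] \]

Where \(\mathbf{W}_{A}, \mathbf{W}_{D_j} \in \mathbb{R}^{1 \times D \times D}\) are initialized as identity matrices. Only the diagonal entry \(\mathbf{W}[0, d, d]\) is read for channel \(d\), so channels are not mixed with one another: each level of each channel is scaled by one learnable scalar.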

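Level Mixing (mixing_mode='level'):

The coefficient sequences of all \(J+1\) levels are zero-padded to the longest length \(L_{\max}\). Each padded level is then treated as a separate sequence of \(L_{\max}\) scalar tokens, and single-head self-attention with embedding dimension 1 is applied over the positions within that level:

\[ \tilde{\mathbf{c}}_{\ell}^{(d)} = \text{Attn}\big(\text{Pad}(\mathbf{c}_{\ell}^{(d)}, L_{\max})\big)[:, :L_{\ell}], \quad \ell \in \{A_J, D_1, \ldots, D_J\} \]

Where \(L_{\ell}\) is the original length of level \(\ell\). One attention module is shared by all levels and channels. Levels do not attend to one another, and padded positions take part in the attention because no key padding mask is passed.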
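The reshaping in the 'level' branch of forward can be sketched in isolation with nn.MultiheadAttention. The coefficient lengths below are made up for illustration; the rest mirrors the source: zero-padding to the longest level, one sequence of scalar tokens per (batch, level) row, and a single attention module with embedding dimension 1.

```python
import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(0)
batch, levels = 4, 3
# Hypothetical coefficient lengths for one channel: approx first, then details
coeffs = [torch.randn(batch, length) for length in (12, 12, 20, 36)]

# Zero-pad every level to the longest one and stack: (B, J+1, max_len)
max_len = max(c.shape[-1] for c in coeffs)
stacked = torch.stack([F.pad(c, (0, max_len - c.shape[-1])) for c in coeffs], dim=1)

# Each (batch, level) row becomes its own sequence of scalar tokens: (B*(J+1), max_len, 1)
tokens = stacked.view(batch * (levels + 1), max_len, 1)

attn = nn.MultiheadAttention(1, num_heads=1, batch_first=True)
mixed, weights = attn(tokens, tokens, tokens)
print(weights.shape)  # torch.Size([16, 36, 36]): positions attend within one level only

# Separate batch and levels again, then crop each level to its original length
mixed = mixed.reshape(batch, levels + 1, max_len)
mixed_coeffs = [mixed[:, i, : c.shape[-1]] for i, c in enumerate(coeffs)]
```

reshape is used instead of view for the attention output because that output is not guaranteed to be contiguous.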

Step 3: Reconstruction

Reconstruct the signal using inverse DWT:

\[ \tilde{\mathbf{x}}^{(d)} = \text{IDWT}_J(\{\tilde{\mathbf{c}}_{A_J}^{(d)}, \{\tilde{\mathbf{c}}_{D_j}^{(d)}\}_{j=1}^J\}) \]

Apply length adjustment if necessary:

\[ \hat{\mathbf{x}}^{(d)} = \begin{cases} \tilde{\mathbf{x}}^{(d)}[:, :N] & \text{if } |\tilde{\mathbf{x}}^{(d)}| > N \\ \text{Pad}(\tilde{\mathbf{x}}^{(d)}, N) & \text{if } |\tilde{\mathbf{x}}^{(d)}| < N \\ \tilde{\mathbf{x}}^{(d)} & \text{otherwise} \end{cases} \]

Step 4: Residual Connection and Dropout

Combine all channels and apply residual connection:

\[ \hat{\mathbf{X}} = \text{Concat}(\{\hat{\mathbf{x}}^{(d)}\}_{d=0}^{D-1}) \in \mathbb{R}^{B \times N \times D} \]
\[ \mathbf{Y} = \mathbf{X} + \text{Dropout}(\hat{\mathbf{X}}) \]
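Steps 1 to 4 can be traced end to end with a Haar transform standing in for db4. This sketch assumes one learnable scalar per level and channel, which is what the pointwise and channel code paths in the source reduce to. wavelet_mix, haar_dwt and haar_idwt are hypothetical names, \(N\) must be divisible by \(2^J\), and dropout is omitted.

```python
import math

import torch


def haar_dwt(x: torch.Tensor, levels: int) -> tuple[torch.Tensor, list[torch.Tensor]]:
    """Orthonormal Haar DWT along the last axis; details[0] is the finest level."""
    details, approx = [], x
    for _ in range(levels):
        even, odd = approx[..., 0::2], approx[..., 1::2]
        details.append((even - odd) / math.sqrt(2))
        approx = (even + odd) / math.sqrt(2)
    return approx, details


def haar_idwt(approx: torch.Tensor, details: list[torch.Tensor]) -> torch.Tensor:
    """Inverse of haar_dwt: undo the coarsest level first."""
    for detail in reversed(details):
        even = (approx + detail) / math.sqrt(2)
        odd = (approx - detail) / math.sqrt(2)
        approx = torch.stack((even, odd), dim=-1).flatten(-2)  # interleave even/odd samples
    return approx


def wavelet_mix(x: torch.Tensor, w_approx: torch.Tensor, w_details: torch.Tensor) -> torch.Tensor:
    """Hypothetical scalar-per-level wavelet mixing.

    x: (B, N, D); w_approx: (D,); w_details: (J, D).
    """
    levels, n = w_details.shape[0], x.shape[1]
    outputs = []
    for d in range(x.shape[-1]):  # Step 1: one channel at a time
        approx, details = haar_dwt(x[:, :, d], levels)
        approx = approx * w_approx[d]  # Step 2: one scalar per level and channel
        details = [c * w_details[j, d] for j, c in enumerate(details)]
        rec = haar_idwt(approx, details)[:, :n]  # Step 3: IDWT (length is already N here)
        outputs.append(rec.unsqueeze(-1))
    x_hat = torch.cat(outputs, dim=-1)  # Step 4: restack the channels
    return x + x_hat  # residual connection


torch.manual_seed(0)
x = torch.randn(2, 32, 4)  # (B, N, D) with N divisible by 2**3
J, D = 3, x.shape[-1]
y_ones = wavelet_mix(x, torch.ones(D), torch.ones(J, D))
y_zeros = wavelet_mix(x, torch.zeros(D), torch.zeros(J, D))
print(torch.allclose(y_ones, 2 * x, atol=1e-5), torch.equal(y_zeros, x))  # True True
```

With all weights equal to one, the wavelet branch reconstructs the input (Haar is orthonormal), so the residual connection doubles it; with all weights zero the layer reduces to the identity. An identity-initialized layer should therefore start near \(\mathbf{Y} \approx 2\mathbf{X}\) rather than \(\mathbf{Y} \approx \mathbf{X}\).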
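Complexity Analysis
  • Time Complexity: \(O(DN)\) per forward pass for the 'pointwise' and 'channel' modes

    • \(O(N)\) per channel for the full \(J\)-level DWT and IDWT, because each level halves the signal length (the constant grows with the wavelet filter length)
    • \(O(N)\) per channel for scaling the coefficients
    • The 'level' mode adds self-attention over \(J+1\) sequences of up to roughly \(N/2\) positions per channel, i.e. \(O(DJN^2)\)
  • Space Complexity: \(O(DN + P)\) where \(P\) is the parameter count

    • \(O(DN)\) for storing coefficient tensors, plus \(O(DJN^2)\) attention matrices in 'level' mode
    • Parameter count depends on mixing mode:
      • Pointwise: \(P = (J+1)D\), one scalar per level and channel
      • Channel: \(P = (J+1)D^2\), although only the \(D\) diagonal entries of each matrix are used
      • Level: \(P = 8\), the parameters of a single-head nn.MultiheadAttention with embedding dimension 1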
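The parameter counts can be checked by rebuilding the parameters the way the __init__ source for this class creates them and counting them. mixing_param_count is a hypothetical helper and does not import spectrans.

```python
import torch
from torch import nn


def count_params(module: nn.Module) -> int:
    return sum(p.numel() for p in module.parameters())


def mixing_param_count(hidden_dim: int, levels: int, mode: str) -> int:
    """Hypothetical helper: rebuild the parameters WaveletMixing.__init__ creates, then count them."""
    names = ["approx"] + [f"detail_{level}" for level in range(levels)]
    if mode == "pointwise":
        # One (1, 1, D) weight per level: a scalar per level and channel
        weights = nn.ParameterDict({k: nn.Parameter(torch.ones(1, 1, hidden_dim)) for k in names})
    elif mode == "channel":
        # One (1, D, D) identity matrix per level
        weights = nn.ParameterDict(
            {k: nn.Parameter(torch.eye(hidden_dim).unsqueeze(0)) for k in names}
        )
    elif mode == "level":
        # A single attention module with embedding dimension 1
        return count_params(nn.MultiheadAttention(1, num_heads=1, batch_first=True))
    else:
        raise ValueError(f"Unknown mixing mode: {mode}")
    return count_params(weights)


for mode in ("pointwise", "channel", "level"):
    print(mode, mixing_param_count(hidden_dim=64, levels=3, mode=mode))
# pointwise 256
# channel 16384
# level 8
```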
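Implementation Notes
  • Uses a PyTorch-native DWT implementation, so gradients flow through both decomposition and reconstruction
  • Mixing weights are indexed per channel at run time, so coefficient lengths do not need to be known when the layer is constructed
  • After the IDWT the signal is truncated or zero-padded back to length \(N\), since boundary extension and odd lengths may change the reconstructed length by a few samples
  • Channels are processed independently in a Python loop over the \(D\) hidden dimensions; this keeps channels separable but runs \(D\) separate DWT/IDWT pairs per forward pass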

Parameters:

Name Type Description Default
hidden_dim int

Hidden dimension size \(D\).

required
wavelet str

Wavelet type (e.g., 'db1', 'db4', 'sym2'). Determines filter bank characteristics.

'db4'
levels int

Number of decomposition levels \(J\). Controls resolution hierarchy.

3
mixing_mode str

Mixing strategy: 'pointwise' (element-wise), 'channel' (diagonal), 'level' (attention).

'pointwise'
dropout float

Dropout probability applied to the reconstructed signal before the residual connection. In 'level' mode it is also passed to nn.MultiheadAttention as attention dropout.

0.0

Attributes:

Name Type Description
dwt DWT1D

Wavelet transform module implementing PyTorch-native DWT/IDWT.

mixing_weights ParameterDict

Learnable parameters for coefficient mixing; their shapes depend on mixing_mode. Empty in 'level' mode, which uses the level_mixer attention module instead.

dropout Dropout

Dropout layer for regularization.

Raises:

Type Description
ValueError

If mixing_mode is not one of {'pointwise', 'channel', 'level'}.

Examples:

Basic usage with pointwise mixing:

>>> mixer = WaveletMixing(hidden_dim=256, wavelet='db4', levels=3)
>>> x = torch.randn(32, 128, 256)  # (batch, seq_len, hidden)
>>> output = mixer(x)
>>> assert output.shape == x.shape

Channel mixing with identity initialization:

>>> mixer = WaveletMixing(hidden_dim=64, mixing_mode='channel', levels=2)
>>> x = torch.randn(16, 64, 64)
>>> output = mixer(x)
>>> # Identity-initialized weights make the wavelet branch approximately reconstruct x, so initially output ≈ 2 * x

Cross-level mixing with attention:

>>> mixer = WaveletMixing(hidden_dim=128, mixing_mode='level', levels=4)
>>> x = torch.randn(8, 256, 128)
>>> output = mixer(x)  # Self-attention over positions within each wavelet level

Methods:

Name Description
forward

Apply wavelet-based mixing following the mathematical formulation.

from_config

Create WaveletMixing from configuration.

Source code in spectrans/layers/mixing/wavelet.py
def __init__(
    self,
    hidden_dim: int,
    wavelet: WaveletType = "db4",
    levels: int = 3,
    mixing_mode: str = "pointwise",
    dropout: float = 0.0,
):
    super().__init__()

    self.hidden_dim = hidden_dim
    self.wavelet = wavelet
    self.levels = levels
    self.mixing_mode = mixing_mode

    # Initialize wavelet transform
    self.dwt = DWT1D(wavelet=wavelet, levels=levels, mode="symmetric")

    # Initialize mixing weights based on mode
    self.mixing_weights = nn.ParameterDict()

    if mixing_mode == "pointwise":
        # Simple pointwise multiplication for each level
        self.mixing_weights["approx"] = nn.Parameter(torch.ones(1, 1, hidden_dim))
        for level in range(levels):
            self.mixing_weights[f"detail_{level}"] = nn.Parameter(torch.ones(1, 1, hidden_dim))

    elif mixing_mode == "channel":
        # Channel-wise mixing matrices
        self.mixing_weights["approx"] = nn.Parameter(torch.eye(hidden_dim).unsqueeze(0))
        for level in range(levels):
            self.mixing_weights[f"detail_{level}"] = nn.Parameter(
                torch.eye(hidden_dim).unsqueeze(0)
            )

    elif mixing_mode == "level":
        # Cross-level mixing with attention-like mechanism
        # Use 1 as embedding dim since we process each channel independently
        self.level_mixer = nn.MultiheadAttention(
            1,
            num_heads=1,  # Feature dim=1, so only 1 head possible
            dropout=dropout,
            batch_first=True,
        )
    else:
        raise ValueError(f"Unknown mixing mode: {mixing_mode}")

    self.dropout = nn.Dropout(dropout)
Functions
forward
forward(x: Tensor) -> Tensor

Apply wavelet-based mixing following the mathematical formulation.

Implements the complete wavelet mixing pipeline: decomposition → mixing → reconstruction → residual. Each hidden dimension is processed independently to maintain channel separability.

Mathematical Implementation

The forward pass follows these steps:

  1. Channel Extraction: \(\mathbf{x}^{(d)} = \mathbf{X}[:, :, d]\) for \(d = 0, \ldots, D-1\)
  2. Wavelet Decomposition: \(\text{DWT}_J(\mathbf{x}^{(d)}) \rightarrow \{\mathbf{c}_{A_J}^{(d)}, \{\mathbf{c}_{D_j}^{(d)}\}\}\)
  3. Learnable Mixing: Apply mode-specific transformations to coefficients
  4. Signal Reconstruction: \(\text{IDWT}_J(\text{mixed coefficients}) \rightarrow \hat{\mathbf{x}}^{(d)}\)
  5. Channel Concatenation: \(\hat{\mathbf{X}} = [\hat{\mathbf{x}}^{(0)}, \ldots, \hat{\mathbf{x}}^{(D-1)}]\)
  6. Residual Connection: \(\mathbf{Y} = \mathbf{X} + \text{Dropout}(\hat{\mathbf{X}})\)

Parameters:

Name Type Description Default
x Tensor

Input tensor of shape \((B, N, D)\) where:

  • \(B\) is batch size
  • \(N\) is sequence length
  • \(D\) is hidden dimension
required

Returns:

Type Description
Tensor

Mixed output tensor of identical shape \((B, N, D)\) with wavelet-domain mixing applied and residual connection.

Notes
  • Coefficient lengths are read at run time, so any sequence length is accepted
  • The reconstructed signal is truncated or zero-padded back to length \(N\) before the residual connection
  • All steps are differentiable PyTorch operations, so gradients reach both the mixing weights and the input
Source code in spectrans/layers/mixing/wavelet.py
def forward(self, x: Tensor) -> Tensor:
    r"""Apply wavelet-based mixing following the mathematical formulation.

    Implements the complete wavelet mixing pipeline: decomposition → mixing → reconstruction → residual.
    Each hidden dimension is processed independently to maintain channel separability.

    Mathematical Implementation
    ---------------------------
    The forward pass follows these steps:

    1. **Channel Extraction**: $\mathbf{x}^{(d)} = \mathbf{X}[:, :, d]$ for $d = 0, \ldots, D-1$
    2. **Wavelet Decomposition**: $\text{DWT}_J(\mathbf{x}^{(d)}) \rightarrow \{\mathbf{c}_{A_J}^{(d)}, \{\mathbf{c}_{D_j}^{(d)}\}\}$
    3. **Learnable Mixing**: Apply mode-specific transformations to coefficients
    4. **Signal Reconstruction**: $\text{IDWT}_J(\text{mixed coefficients}) \rightarrow \hat{\mathbf{x}}^{(d)}$
    5. **Channel Concatenation**: $\hat{\mathbf{X}} = [\hat{\mathbf{x}}^{(0)}, \ldots, \hat{\mathbf{x}}^{(D-1)}]$
    6. **Residual Connection**: $\mathbf{Y} = \mathbf{X} + \text{Dropout}(\hat{\mathbf{X}})$

    Parameters
    ----------
    x : Tensor
        Input tensor of shape $(B, N, D)$ where:

        - $B$ is batch size
        - $N$ is sequence length
        - $D$ is hidden dimension

    Returns
    -------
    Tensor
        Mixed output tensor of identical shape $(B, N, D)$ with wavelet-domain
        mixing applied and residual connection.

    Notes
    -----
    - Dynamic coefficient length handling ensures robustness to varying sequence lengths
    - Perfect reconstruction property maintained through careful padding/truncation
    - Gradient flow preserved through PyTorch-native operations
    """
    _, seq_len, hidden_dim = x.shape

    # Store original input for residual connection
    residual = x

    # Process each hidden dimension independently
    outputs = []
    for h in range(hidden_dim):
        # Extract single channel and squeeze to 2D for DWT
        x_channel = x[:, :, h]  # Shape: [batch, seq_len]

        # Decompose using DWT
        approx, details = self.dwt.decompose(x_channel, dim=-1)

        # Apply mixing based on mode
        if self.mixing_mode == "pointwise":
            # Weights have shape (1, 1, hidden_dim), so this slice is a (1, 1) scalar
            # for channel h that broadcasts over approx's [batch, approx_len] shape
            approx_len = approx.shape[-1]
            approx_weight = self.mixing_weights["approx"][:, :approx_len, h]
            approx_mixed = approx * approx_weight

            details_mixed = []
            for level, detail in enumerate(details):
                detail_len = detail.shape[-1]
                weight = self.mixing_weights[f"detail_{level}"][:, :detail_len, h]
                details_mixed.append(detail * weight)

        elif self.mixing_mode == "channel":
            # Apply channel mixing (simplified for single channel processing)
            approx_mixed = approx * self.mixing_weights["approx"][:, h, h]
            details_mixed = []
            for level, detail in enumerate(details):
                weight = self.mixing_weights[f"detail_{level}"][:, h, h]
                details_mixed.append(detail * weight)

        elif self.mixing_mode == "level":
            # Stack all coefficients for cross-level mixing
            all_coeffs = [approx, *details]
            max_len = max(c.shape[-1] for c in all_coeffs)  # Use -1 for last dimension

            # Pad to same length
            padded_coeffs = []
            for coeff in all_coeffs:
                if coeff.shape[-1] < max_len:  # Use -1 to work with last dimension
                    pad_len = max_len - coeff.shape[-1]
                    coeff = F.pad(coeff, (0, pad_len))  # Pad the last dimension
                padded_coeffs.append(coeff)

            # Stack and apply attention
            stacked = torch.stack(padded_coeffs, dim=1)  # (batch, levels+1, max_len)

            # Reshape for attention: (batch, levels+1, max_len) -> (batch * (levels+1), max_len, 1)
            batch_size_coeff = stacked.shape[0]
            num_levels = stacked.shape[1]
            seq_len_coeff = stacked.shape[2]

            # Flatten batch and levels, then add feature dimension
            stacked_flat = stacked.view(
                batch_size_coeff * num_levels, seq_len_coeff, 1
            )  # (batch * levels, seq_len, 1)

            # Apply self-attention across sequence positions for each level independently
            mixed_flat, _ = self.level_mixer(stacked_flat, stacked_flat, stacked_flat)

            # Reshape back to separate batch and levels
            mixed = mixed_flat.view(
                batch_size_coeff, num_levels, seq_len_coeff, 1
            )  # Feature dim is 1, not hidden_dim

            # Extract mixed coefficients
            approx_mixed = mixed[
                :, 0, : approx.shape[-1], 0
            ]  # Extract approx coeffs for current channel
            details_mixed = []
            for level in range(self.levels):
                detail_len = details[level].shape[-1]
                detail_mixed = mixed[
                    :, level + 1, :detail_len, 0
                ]  # Extract detail coeffs for current channel
                details_mixed.append(detail_mixed)

        # Reconstruct signal
        reconstructed = self.dwt.reconstruct((approx_mixed, details_mixed), dim=-1)

        # Ensure output has correct length
        if reconstructed.shape[-1] != seq_len:
            if reconstructed.shape[-1] > seq_len:
                reconstructed = reconstructed[:, :seq_len]
            else:
                # Pad if needed
                pad_len = seq_len - reconstructed.shape[-1]
                reconstructed = F.pad(reconstructed, (0, pad_len))

        outputs.append(reconstructed.unsqueeze(-1))  # Add channel dim back

    # Combine all channels
    output = torch.cat(outputs, dim=-1)

    # Apply dropout and residual connection
    output = self.dropout(output)
    output = output + residual

    result: Tensor = output
    return result
from_config classmethod
from_config(config: WaveletMixingConfig) -> WaveletMixing

Create WaveletMixing from configuration.

Parameters:

Name Type Description Default
config WaveletMixingConfig

Typed and validated configuration.

required

Returns:

Type Description
WaveletMixing

Configured instance.

Source code in spectrans/layers/mixing/wavelet.py
@classmethod
def from_config(cls, config: "WaveletMixingConfig") -> "WaveletMixing":
    """Create WaveletMixing from configuration.

    Parameters
    ----------
    config : WaveletMixingConfig
        Typed and validated configuration.

    Returns
    -------
    WaveletMixing
        Configured instance.
    """
    return cls(
        hidden_dim=config.hidden_dim,
        wavelet=config.wavelet,
        levels=config.levels,
        mixing_mode=config.mixing_mode,
        dropout=config.dropout,
    )

WaveletMixing2D

WaveletMixing2D(channels: int, wavelet: WaveletType = 'db4', levels: int = 2, mixing_mode: str = 'subband')

Bases: Module

2D wavelet mixing layer for image-like data.

Performs mixing in the 2D wavelet domain, which suits vision transformers and other architectures that process 2D spatial data. Spatial information is processed through multi-resolution wavelet subbands.

Mathematical Formulation

Given input tensor \(\mathbf{X} \in \mathbb{R}^{B \times C \times H \times W}\) where \(B\) is batch size, \(C\) is channels, \(H\) is height, and \(W\) is width:

Step 1: Channel-wise 2D Decomposition

For each channel \(c \in \{0, 1, \ldots, C-1\}\), extract spatial data:

\[ \mathbf{X}^{(c)} = \mathbf{X}[:, c, :, :] \in \mathbb{R}^{B \times H \times W} \]

Apply \(J\)-level 2D DWT decomposition:

\[ \text{DWT2D}_J(\mathbf{X}^{(c)}) = \{\mathbf{LL}_J^{(c)}, \{(\mathbf{LH}_j^{(c)}, \mathbf{HL}_j^{(c)}, \mathbf{HH}_j^{(c)})\}_{j=1}^J\} \]

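Where:
  • \(\mathbf{LL}_J^{(c)} \in \mathbb{R}^{B \times H_J \times W_J}\) is the approximation (low-low) subband
  • \(\mathbf{LH}_j^{(c)}, \mathbf{HL}_j^{(c)}, \mathbf{HH}_j^{(c)} \in \mathbb{R}^{B \times H_j \times W_j}\) are the detail subbands at level \(j\)
  • \(H_j \approx H / 2^j\) and \(W_j \approx W / 2^j\) are the spatial dimensions at level \(j\); they are typically exact for the Haar wavelet when \(H\) and \(W\) are divisible by \(2^j\), and slightly larger for longer filters because of boundary extension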
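One level of the 2D decomposition can be illustrated with the orthonormal 2D Haar transform, where each 2 × 2 block of pixels yields one coefficient in each subband. This is an illustrative stand-in for DWT2D, not its implementation, and LH/HL naming conventions differ between libraries; here LH holds row differences and HL column differences. haar_dwt2 and haar_idwt2 are hypothetical helpers.

```python
import torch


def haar_dwt2(x: torch.Tensor) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
    """One level of the orthonormal 2D Haar DWT over the last two axes (even H and W)."""
    a, b = x[..., 0::2, 0::2], x[..., 0::2, 1::2]  # top-left, top-right of each 2x2 block
    c, d = x[..., 1::2, 0::2], x[..., 1::2, 1::2]  # bottom-left, bottom-right
    ll = (a + b + c + d) / 2  # low-low: scaled block average
    lh = (a + b - c - d) / 2  # top row minus bottom row
    hl = (a - b + c - d) / 2  # left column minus right column
    hh = (a - b - c + d) / 2  # diagonal difference
    return ll, (lh, hl, hh)


def haar_idwt2(
    ll: torch.Tensor, details: tuple[torch.Tensor, torch.Tensor, torch.Tensor]
) -> torch.Tensor:
    """Exact inverse of haar_dwt2."""
    lh, hl, hh = details
    out = ll.new_empty(*ll.shape[:-2], 2 * ll.shape[-2], 2 * ll.shape[-1])
    out[..., 0::2, 0::2] = (ll + lh + hl + hh) / 2
    out[..., 0::2, 1::2] = (ll + lh - hl - hh) / 2
    out[..., 1::2, 0::2] = (ll - lh + hl - hh) / 2
    out[..., 1::2, 1::2] = (ll - lh - hl + hh) / 2
    return out


x = torch.randn(2, 1, 32, 32)  # (B, 1, H, W): one channel, as sliced in forward
ll, subbands = x, []
for _ in range(2):  # J = 2 levels: re-decompose the LL band each time
    ll, details = haar_dwt2(ll)
    subbands.append(details)
print(ll.shape[-2:], [lh.shape[-2:] for lh, _, _ in subbands])
# torch.Size([8, 8]) [torch.Size([16, 16]), torch.Size([8, 8])]
```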

Step 2: Subband Mixing

Apply mixing transformations based on mode:

Subband Mixing (mixing_mode='subband'):

Independent processing of each subband using convolutional networks:

\[ \tilde{\mathbf{LL}}_J^{(c)} = f_{LL}(\mathbf{LL}_J^{(c)}) \]
\[ \tilde{\mathbf{LH}}_j^{(c)} = f_{LH}^{(j)}(\mathbf{LH}_j^{(c)}), \quad \tilde{\mathbf{HL}}_j^{(c)} = f_{HL}^{(j)}(\mathbf{HL}_j^{(c)}), \quad \tilde{\mathbf{HH}}_j^{(c)} = f_{HH}^{(j)}(\mathbf{HH}_j^{(c)}) \]

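Where \(f_{LL}\) is a \(3 \times 3\) convolution followed by batch normalization and ReLU, and each detail map \(f^{(j)}\) is a single \(3 \times 3\) convolution. All of them have one input and one output channel and are shared across the \(C\) channels.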

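Cross Mixing (mixing_mode='cross'):

All subbands of a channel are flattened into a single sequence of scalar tokens, and single-head self-attention with embedding dimension 1 runs over that whole sequence, so each coefficient can attend to coefficients in every subband and level. The output sequence is then split and reshaped back into subbands:

\[ \{\tilde{\mathbf{LL}}_J^{(c)}, \{\tilde{\mathbf{LH}}_j^{(c)}, \tilde{\mathbf{HL}}_j^{(c)}, \tilde{\mathbf{HH}}_j^{(c)}\}\} = \text{Attn}\big(\text{Concat}(\text{vec}(\mathbf{LL}_J^{(c)}), \{\text{vec}(\mathbf{LH}_j^{(c)}), \text{vec}(\mathbf{HL}_j^{(c)}), \text{vec}(\mathbf{HH}_j^{(c)})\}_{j=1}^J)\big) \]

Where \(\text{vec}(\cdot)\) flattens a subband's spatial dimensions into a sequence.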
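The token bookkeeping of the cross mode can be sketched for one channel. The subband shapes below are hypothetical (those of a 2-level Haar decomposition of a 16 × 16 channel); the flatten, concatenate, attend and split steps follow the source. With Haar and even sizes the sequence holds exactly \(HW = 256\) tokens, which is why the attention cost grows with \((HW)^2\).

```python
import torch
from torch import nn

torch.manual_seed(0)
B = 2
# Hypothetical subbands of one channel (J = 2): LL_2 plus (LH, HL, HH) per level
ll = torch.randn(B, 1, 4, 4)
details = [tuple(torch.randn(B, 1, s, s) for _ in range(3)) for s in (8, 4)]

subbands = [ll] + [band for triple in details for band in triple]
# Flatten each subband to (B, h*w, 1): one scalar token per coefficient
tokens = torch.cat([band.flatten(2).transpose(1, 2) for band in subbands], dim=1)

attn = nn.MultiheadAttention(1, num_heads=1, batch_first=True)
mixed, _ = attn(tokens, tokens, tokens)  # self-attention over every coefficient

# Split the token sequence back into subbands with their original shapes
sizes = [band.shape[-2] * band.shape[-1] for band in subbands]
chunks = torch.split(mixed, sizes, dim=1)
mixed_bands = [chunk.transpose(1, 2).reshape_as(band) for chunk, band in zip(chunks, subbands)]
print(tokens.shape)  # torch.Size([2, 256, 1])
```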

Step 3: 2D Reconstruction

Reconstruct the spatial signal:

\[ \tilde{\mathbf{X}}^{(c)} = \text{IDWT2D}_J(\{\tilde{\mathbf{LL}}_J^{(c)}, \{\tilde{\mathbf{LH}}_j^{(c)}, \tilde{\mathbf{HL}}_j^{(c)}, \tilde{\mathbf{HH}}_j^{(c)}\}\}) \]

Step 4: Channel Concatenation and Residual

\[ \hat{\mathbf{X}} = \text{Stack}(\{\tilde{\mathbf{X}}^{(c)}\}_{c=0}^{C-1}) \in \mathbb{R}^{B \times C \times H \times W} \]
\[ \mathbf{Y} = \mathbf{X} + \hat{\mathbf{X}} \]
Complexity Analysis
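  • Time Complexity: \(O(CHW) + O(\text{mixing operations})\), because each DWT level works on about a quarter of the previous level's coefficients
  • Space Complexity: \(O(CHW + \text{subband storage})\)

Where mixing complexity depends on mode:
  • Subband: \(O(CHW)\), since every subband passes through fixed-size \(3 \times 3\) convolutions
  • Cross: \(O(C(HW)^2)\), since self-attention runs over all (roughly \(HW\)) subband coefficients of a channel at once
  • Attention: subband_attention is constructed in __init__ but not called in the forward pass shown, so this mode currently adds no mixing cost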

Parameters:

Name Type Description Default
channels int

Number of input/output channels \(C\).

required
wavelet str

Wavelet type determining 2D filter bank characteristics.

'db4'
levels int

Number of decomposition levels \(J\).

2
mixing_mode str

Subband mixing strategy: 'subband' (independent), 'cross' (attention), 'attention' (transformer).

'subband'

Attributes:

Name Type Description
dwt DWT2D

2D wavelet transform module.

ll_mixer Sequential

Convolutional network for LL subband (subband mode).

detail_mixers ModuleList

Convolutional networks for detail subbands (subband mode).

cross_mixer MultiheadAttention

Cross-attention module (cross mode).

subband_attention TransformerEncoder

Transformer encoder for subband attention (attention mode).

Raises:

Type Description
ValueError

If mixing_mode is not one of {'subband', 'cross', 'attention'}.

Examples:

Independent subband processing:

>>> mixer = WaveletMixing2D(channels=256, wavelet='db4', levels=2)
>>> x = torch.randn(32, 256, 64, 64)  # (batch, channels, height, width)
>>> output = mixer(x)
>>> assert output.shape == x.shape

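Cross-subband attention (the attention cost grows with the square of the number of wavelet coefficients, so keep spatial sizes small):

>>> mixer = WaveletMixing2D(channels=4, mixing_mode='cross', levels=2)
>>> x = torch.randn(2, 4, 32, 32)
>>> output = mixer(x)  # Self-attention over the coefficients of all subbands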

Methods:

Name Description
forward

Apply 2D wavelet-based mixing following the mathematical formulation.

from_config

Create WaveletMixing2D from configuration.

Source code in spectrans/layers/mixing/wavelet.py
def __init__(
    self,
    channels: int,
    wavelet: WaveletType = "db4",
    levels: int = 2,
    mixing_mode: str = "subband",
):
    super().__init__()

    self.channels = channels
    self.wavelet = wavelet
    self.levels = levels
    self.mixing_mode = mixing_mode

    # Initialize 2D wavelet transform
    self.dwt = DWT2D(wavelet=wavelet, levels=levels, mode="symmetric")

    # Initialize mixing layers based on mode
    if mixing_mode == "subband":
        # Independent processing of each subband
        # Each subband from DWT has 1 channel, so conv layers should expect 1 channel input
        self.ll_mixer = nn.Sequential(
            nn.Conv2d(1, 1, 3, padding=1),  # 1 channel in/out for single subband
            nn.BatchNorm2d(1),
            nn.ReLU(inplace=True),
        )

        self.detail_mixers = nn.ModuleList()
        for _ in range(levels):
            detail_mixer = nn.ModuleDict(
                {
                    "lh": nn.Conv2d(1, 1, 3, padding=1),  # 1 channel in/out per detail subband
                    "hl": nn.Conv2d(1, 1, 3, padding=1),
                    "hh": nn.Conv2d(1, 1, 3, padding=1),
                }
            )
            self.detail_mixers.append(detail_mixer)

    elif mixing_mode == "cross":
        # Cross-subband interaction
        # Each subband is processed per-channel with feature dimension 1 after flattening spatial dims
        # So attention operates on sequences of spatial positions with 1 feature per position
        self.cross_mixer = nn.MultiheadAttention(
            1,
            num_heads=1,  # Feature dim=1, so only 1 head possible
            batch_first=True,
        )

    elif mixing_mode == "attention":
        # Attention-based mixing across all subbands
        # Same as cross mode - feature dimension is 1 after spatial flattening
        self.subband_attention = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=1,  # Feature dimension is 1 after flattening spatial dimensions
                nhead=1,  # Only 1 head possible with d_model=1
                dim_feedforward=4,  # Minimal FFN since d_model=1
                batch_first=True,
            ),
            num_layers=2,
        )
    else:
        raise ValueError(f"Unknown mixing mode: {mixing_mode}")
Functions
forward
forward(x: Tensor) -> Tensor

Apply 2D wavelet-based mixing following the mathematical formulation.

Implements complete 2D wavelet mixing: spatial decomposition → subband mixing → reconstruction → residual connection. Each channel is processed independently.

Mathematical Implementation
  1. Channel Extraction: \(\mathbf{X}^{(c)} = \mathbf{X}[:, c, :, :]\) for each channel \(c\)
  2. 2D Wavelet Decomposition: \(\text{DWT2D}_J(\mathbf{X}^{(c)}) \rightarrow \text{subbands}\)
  3. Subband Mixing: Apply mode-specific transformations to wavelet subbands
  4. 2D Reconstruction: \(\text{IDWT2D}_J(\text{mixed subbands}) \rightarrow \tilde{\mathbf{X}}^{(c)}\)
  5. Channel Stacking: \(\hat{\mathbf{X}} = [\tilde{\mathbf{X}}^{(0)}, \ldots, \tilde{\mathbf{X}}^{(C-1)}]\)
  6. Residual Connection: \(\mathbf{Y} = \mathbf{X} + \hat{\mathbf{X}}\)

Parameters:

Name Type Description Default
x Tensor

Input tensor of shape \((B, C, H, W)\) where:

  • \(B\) is batch size
  • \(C\) is number of channels
  • \(H\) is height
  • \(W\) is width
required

Returns:

Type Description
Tensor

Mixed output tensor of identical shape \((B, C, H, W)\) with 2D wavelet-domain mixing applied and residual connection.

Notes
  • Spatial dimensions preserved through careful reconstruction handling
  • Different mixing strategies provide various inductive biases
  • Subband mode: Independent processing emphasizes local features
  • Cross mode: Attention enables global subband interactions
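  • Attention mode: intended to apply a transformer encoder across all subbands; in the code shown, subband_attention is constructed but not called, so the subbands pass through unchanged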
Source code in spectrans/layers/mixing/wavelet.py
def forward(self, x: Tensor) -> Tensor:
    r"""Apply 2D wavelet-based mixing following the mathematical formulation.

    Implements complete 2D wavelet mixing: spatial decomposition → subband mixing →
    reconstruction → residual connection. Each channel is processed independently.

    Mathematical Implementation
    ---------------------------
    1. **Channel Extraction**: $\mathbf{X}^{(c)} = \mathbf{X}[:, c, :, :]$ for each channel $c$
    2. **2D Wavelet Decomposition**: $\text{DWT2D}_J(\mathbf{X}^{(c)}) \rightarrow \text{subbands}$
    3. **Subband Mixing**: Apply mode-specific transformations to wavelet subbands
    4. **2D Reconstruction**: $\text{IDWT2D}_J(\text{mixed subbands}) \rightarrow \tilde{\mathbf{X}}^{(c)}$
    5. **Channel Stacking**: $\hat{\mathbf{X}} = [\tilde{\mathbf{X}}^{(0)}, \ldots, \tilde{\mathbf{X}}^{(C-1)}]$
    6. **Residual Connection**: $\mathbf{Y} = \mathbf{X} + \hat{\mathbf{X}}$

    Parameters
    ----------
    x : Tensor
        Input tensor of shape $(B, C, H, W)$ where:

        - $B$ is batch size
        - $C$ is number of channels
        - $H$ is height
        - $W$ is width

    Returns
    -------
    Tensor
        Mixed output tensor of identical shape $(B, C, H, W)$ with 2D wavelet-domain
        mixing applied and residual connection.

    Notes
    -----
    - Spatial dimensions preserved through careful reconstruction handling
    - Different mixing strategies provide various inductive biases
    - Subband mode: Independent processing emphasizes local features
    - Cross mode: Attention enables global subband interactions
    - Attention mode: intended to apply a transformer encoder across all subbands; subband_attention is currently not called, so subbands pass through unchanged
    """
    _, channels, height, width = x.shape
    residual = x

    # Process each channel
    outputs = []
    for c in range(channels):
        x_channel = x[:, c : c + 1, :, :]

        # Decompose using 2D DWT
        ll, details = self.dwt.decompose(x_channel, dim=(-2, -1))

        # Apply mixing based on mode
        if self.mixing_mode == "subband":
            # Process LL subband
            ll_mixed = self.ll_mixer(ll)

            # Process detail subbands
            details_mixed = []
            for level, (lh, hl, hh) in enumerate(details):
                mixer = self.detail_mixers[level]
                lh_mixed = mixer["lh"](lh)  # type: ignore
                hl_mixed = mixer["hl"](hl)  # type: ignore
                hh_mixed = mixer["hh"](hh)  # type: ignore
                details_mixed.append((lh_mixed, hl_mixed, hh_mixed))

        elif self.mixing_mode == "cross":
            # Flatten spatial dimensions for attention
            ll_flat = ll.flatten(2).transpose(1, 2)
            details_flat = []
            for lh, hl, hh in details:
                details_flat.extend(
                    [
                        lh.flatten(2).transpose(1, 2),
                        hl.flatten(2).transpose(1, 2),
                        hh.flatten(2).transpose(1, 2),
                    ]
                )

            # Apply cross-attention
            all_subbands = torch.cat([ll_flat, *details_flat], dim=1)
            mixed, _ = self.cross_mixer(all_subbands, all_subbands, all_subbands)

            # Reshape back
            ll_size = ll.shape[2] * ll.shape[3]
            ll_mixed = mixed[:, :ll_size, :].transpose(1, 2).reshape_as(ll)

            details_mixed = []
            offset = ll_size
            for lh, hl, hh in details:
                lh_size = lh.shape[2] * lh.shape[3]
                hl_size = hl.shape[2] * hl.shape[3]
                hh_size = hh.shape[2] * hh.shape[3]

                lh_mixed = mixed[:, offset : offset + lh_size, :].transpose(1, 2).reshape_as(lh)
                offset += lh_size
                hl_mixed = mixed[:, offset : offset + hl_size, :].transpose(1, 2).reshape_as(hl)
                offset += hl_size
                hh_mixed = mixed[:, offset : offset + hh_size, :].transpose(1, 2).reshape_as(hh)
                offset += hh_size

                details_mixed.append((lh_mixed, hl_mixed, hh_mixed))

        else:  # attention mode
            # Attention mode: currently a pass-through; subbands are reconstructed unmixed
            ll_mixed = ll
            details_mixed = details

        # Reconstruct
        reconstructed = self.dwt.reconstruct((ll_mixed, details_mixed), dim=(-2, -1))

        # Ensure correct shape
        if reconstructed.shape[-2:] != (height, width):
            reconstructed = reconstructed[:, :, :height, :width]

        outputs.append(reconstructed)

    # Combine channels
    output = torch.cat(outputs, dim=1)

    # Residual connection
    output = output + residual

    return output
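
In cross mode, every subband is flattened into a token sequence, all tokens are concatenated for attention, and the result is split back using running offsets. The following is a minimal sketch of that bookkeeping. It assumes subbands shaped (B, C, h, w) and a list of per-level (lh, hl, hh) tuples, as in the loop above. `pack_subbands` and `unpack_subbands` are hypothetical helpers, not spectrans functions, and `torch.split` stands in for the manual offset arithmetic:

```python
import torch


def pack_subbands(ll, details):
    """Hypothetical helper: flatten (B, C, h, w) subbands into one (B, N, C) token sequence."""
    bands = [ll] + [band for level in details for band in level]
    tokens = torch.cat([b.flatten(2).transpose(1, 2) for b in bands], dim=1)
    return tokens, bands


def unpack_subbands(tokens, bands):
    """Hypothetical helper: split tokens back into subbands using each band's spatial size."""
    sizes = [b.shape[2] * b.shape[3] for b in bands]
    chunks = torch.split(tokens, sizes, dim=1)
    restored = [c.transpose(1, 2).reshape_as(b) for c, b in zip(chunks, bands)]
    ll = restored[0]
    # Regroup the remaining bands into per-level (lh, hl, hh) triples
    details = [tuple(restored[i : i + 3]) for i in range(1, len(restored), 3)]
    return ll, details


# Two levels with different spatial sizes, one channel as in the per-channel loop
ll = torch.randn(2, 1, 4, 4)
details = [
    tuple(torch.randn(2, 1, 4, 4) for _ in range(3)),
    tuple(torch.randn(2, 1, 8, 8) for _ in range(3)),
]
tokens, bands = pack_subbands(ll, details)
ll_back, details_back = unpack_subbands(tokens, bands)
```

The round trip is exact because splitting by the same sizes used for concatenation recovers every subband.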
from_config classmethod
from_config(config: WaveletMixing2DConfig) -> WaveletMixing2D

Create WaveletMixing2D from configuration.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| config | WaveletMixing2DConfig | Typed and validated configuration. | required |

Returns:

| Type | Description |
|---|---|
| WaveletMixing2D | Configured instance. |

Source code in spectrans/layers/mixing/wavelet.py
@classmethod
def from_config(cls, config: "WaveletMixing2DConfig") -> "WaveletMixing2D":
    """Create WaveletMixing2D from configuration.

    Parameters
    ----------
    config : WaveletMixing2DConfig
        Typed and validated configuration.

    Returns
    -------
    WaveletMixing2D
        Configured instance.
    """
    return cls(
        channels=config.channels,
        wavelet=config.wavelet,
        levels=config.levels,
        mixing_mode=config.mixing_mode,
    )

FNOBlock

FNOBlock(hidden_dim: int, modes: int | tuple[int, ...] = 16, mlp_ratio: float = 2.0, activation: ActivationType = 'gelu', dropout: float = 0.0, norm_type: NormType = 'layernorm')

Bases: SpectralComponent

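Complete FNO block with spectral convolution and a feedforward network.

This block applies the FNO layer as a pre-norm residual branch (normalize, apply the FNO layer, apply dropout, add back the input). It then applies an optional feedforward network as a second pre-norm residual branch, which gives a transformer-like block. Normalization is configurable: layer norm, batch norm, or none.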

Parameters:

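| Name | Type | Description | Default |
|---|---|---|---|
| hidden_dim | int | Hidden dimension size. | required |
| modes | int or tuple[int, ...] | Number of Fourier modes to retain, passed to FourierNeuralOperator. | 16 |
| mlp_ratio | float | Expansion ratio of the feedforward network's hidden layer. A value of 0 or less disables the feedforward network. | 2.0 |
| activation | ActivationType | Activation for the FNO layer and the feedforward network: 'gelu', 'relu', 'silu' (or 'swish'), 'tanh', 'sigmoid', or 'identity'. | 'gelu' |
| dropout | float | Dropout probability. | 0.0 |
| norm_type | NormType | Normalization type: 'layernorm', 'batchnorm', or 'none'. | 'layernorm' |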

Attributes:

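| Name | Type | Description |
|---|---|---|
| hidden_dim | int | Hidden dimension size. |
| fno | FourierNeuralOperator | FNO layer for spectral convolution. |
| norm1 | Module or None | First normalization layer; None when norm_type is 'none'. |
| norm2 | Module or None | Second normalization layer; None when norm_type is 'none' or the feedforward network is disabled. |
| ffn | Sequential or None | Feedforward network; None when mlp_ratio is 0 or less. |
| dropout | Dropout | Dropout applied to the FNO branch output before the residual addition. |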

Examples:

>>> import torch
>>> from spectrans.layers import FNOBlock
>>> block = FNOBlock(hidden_dim=64, modes=16, mlp_ratio=2.0)
>>> x = torch.randn(32, 128, 64)
>>> output = block(x)
>>> assert output.shape == x.shape

Methods:

| Name | Description |
|---|---|
| forward | Apply FNO block. |

Source code in spectrans/layers/operators/fno.py
def __init__(
    self,
    hidden_dim: int,
    modes: int | tuple[int, ...] = 16,
    mlp_ratio: float = 2.0,
    activation: ActivationType = "gelu",
    dropout: float = 0.0,
    norm_type: NormType = "layernorm",
):
    super().__init__()

    self.hidden_dim = hidden_dim

    # FNO layer
    self.fno = FourierNeuralOperator(hidden_dim=hidden_dim, modes=modes, activation=activation)

    # Normalization
    self.norm1: nn.Module | None
    self.norm2: nn.Module | None
    self.ffn: nn.Sequential | None
    if norm_type == "layernorm":
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.norm2 = nn.LayerNorm(hidden_dim) if mlp_ratio > 0 else None
    elif norm_type == "batchnorm":
        self.norm1 = nn.BatchNorm1d(hidden_dim)
        self.norm2 = nn.BatchNorm1d(hidden_dim) if mlp_ratio > 0 else None
    elif norm_type == "none":
        self.norm1 = None
        self.norm2 = None
    else:
        raise ValueError(f"Unknown norm_type: {norm_type}")

    # Feedforward network
    if mlp_ratio > 0:
        mlp_hidden = int(hidden_dim * mlp_ratio)
        activation_fn: nn.Module
        if activation == "gelu":
            activation_fn = nn.GELU()
        elif activation == "relu":
            activation_fn = nn.ReLU()
        elif activation == "silu" or activation == "swish":
            activation_fn = nn.SiLU()
        elif activation == "tanh":
            activation_fn = nn.Tanh()
        elif activation == "sigmoid":
            activation_fn = nn.Sigmoid()
        elif activation == "identity":
            activation_fn = nn.Identity()
        else:
            raise ValueError(f"Unsupported activation: {activation}")

        self.ffn = nn.Sequential(
            nn.Linear(hidden_dim, mlp_hidden),
            activation_fn,
            nn.Dropout(dropout),
            nn.Linear(mlp_hidden, hidden_dim),
            nn.Dropout(dropout),
        )
    else:
        self.ffn = None
        self.norm2 = None

    self.dropout = nn.Dropout(dropout)
Functions
forward
forward(x: Tensor) -> Tensor

Apply FNO block.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Input tensor of shape (batch_size, sequence_length, hidden_dim). | required |

Returns:

| Type | Description |
|---|---|
| Tensor | Output tensor of same shape as input. |

Source code in spectrans/layers/operators/fno.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    r"""Apply FNO block.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor of shape (batch_size, sequence_length, hidden_dim).

    Returns
    -------
    torch.Tensor
        Output tensor of same shape as input.
    """
    # FNO with residual connection
    if self.norm1 is not None:
        if isinstance(self.norm1, nn.BatchNorm1d):
            # BatchNorm expects (batch, channels, length)
            x_norm = x.transpose(1, 2)
            x_norm = self.norm1(x_norm)
            x_norm = x_norm.transpose(1, 2)
        else:
            x_norm = self.norm1(x)
    else:
        x_norm = x

    x = x + self.dropout(self.fno(x_norm))

    # Feedforward network with residual connection
    if self.ffn is not None:
        if self.norm2 is not None:
            if isinstance(self.norm2, nn.BatchNorm1d):
                x_norm = x.transpose(1, 2)
                x_norm = self.norm2(x_norm)
                x_norm = x_norm.transpose(1, 2)
            else:
                x_norm = self.norm2(x)
        else:
            x_norm = x

        x = x + self.ffn(x_norm)

    return x
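
The `FNOBlock.forward` source is a pre-norm residual block. The following is a minimal PyTorch sketch of the same ordering. A generic `mixer` module stands in for the FNO layer, and the class name `PreNormResidualBlock` and its arguments are hypothetical, not part of spectrans. The sketch also includes the transpose that `nn.BatchNorm1d` needs, because that module normalizes dimension 1 rather than the last dimension:

```python
import torch
from torch import nn


class PreNormResidualBlock(nn.Module):
    """Hypothetical stand-in that mirrors the structure of FNOBlock.forward."""

    def __init__(self, dim, mixer, mlp_ratio=2.0, dropout=0.0, use_batchnorm=False):
        super().__init__()
        self.mixer = mixer
        norm_cls = nn.BatchNorm1d if use_batchnorm else nn.LayerNorm
        self.norm1, self.norm2 = norm_cls(dim), norm_cls(dim)
        hidden = int(dim * mlp_ratio)
        self.ffn = nn.Sequential(
            nn.Linear(dim, hidden), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(hidden, dim), nn.Dropout(dropout),
        )
        self.dropout = nn.Dropout(dropout)

    def _norm(self, norm, x):
        # BatchNorm1d expects (batch, features, length): move features to dim 1 and back
        if isinstance(norm, nn.BatchNorm1d):
            return norm(x.transpose(1, 2)).transpose(1, 2)
        return norm(x)

    def forward(self, x):
        # x: (batch, sequence, dim); normalize -> mix -> residual, then normalize -> FFN -> residual
        x = x + self.dropout(self.mixer(self._norm(self.norm1, x)))
        return x + self.ffn(self._norm(self.norm2, x))


torch.manual_seed(0)
x = torch.randn(4, 10, 8)
block = PreNormResidualBlock(8, mixer=nn.Identity(), use_batchnorm=True)
y = block(x)
normed = block._norm(block.norm1, x)  # per-feature statistics taken over (batch, sequence)
```

With batch norm in training mode, each feature of `normed` has approximately zero mean over the batch and sequence axes. This confirms that the transpose puts the features where `BatchNorm1d` expects them.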

FourierNeuralOperator

FourierNeuralOperator(hidden_dim: int, modes: int | tuple[int, ...] = 16, activation: ActivationType = 'gelu', use_spectral_conv: bool = True, use_linear: bool = True)

Bases: SpectralComponent

Fourier Neural Operator layer for learning operators in function spaces.

This layer sums a truncated spectral convolution and a pointwise (kernel size 1) linear transformation, then applies an activation, to learn mappings between function spaces.
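
Reading the `forward` method of this class, the layer computes the following for a channels-first input $v$. This is a summary of the code, not a formula quoted from the library's documentation:

$$
u = \sigma\big(\mathcal{K} v + W v + b\big)
$$

Here $\mathcal{K}$ is the spectral convolution (`SpectralConv1d` for an integer `modes`, `SpectralConv2d` for a 2-tuple), $W v + b$ is the pointwise `Conv1d`/`Conv2d` with kernel size 1, and $\sigma$ is the activation. Disabling `use_spectral_conv` or `use_linear` drops the corresponding term. This layer has no residual connection; `FNOBlock` adds one.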

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| hidden_dim | int | Hidden dimension (number of channels). | required |
| modes | int or tuple[int, ...] | Number of Fourier modes to retain. An integer selects the 1D layer and a 2-tuple selects the 2D layer; other tuple lengths raise ValueError. | 16 |
| activation | ActivationType | Activation function: 'gelu', 'relu', 'silu' (or 'swish'), 'tanh', 'sigmoid', or 'identity'. | 'gelu' |
| use_spectral_conv | bool | Whether to use spectral convolution. | True |
| use_linear | bool | Whether to use the pointwise linear transformation. At least one of use_spectral_conv and use_linear must be True, otherwise ValueError is raised. | True |

Attributes:

| Name | Type | Description |
|---|---|---|
| hidden_dim | int | Hidden dimension size. |
| modes | int or tuple[int, ...] | Number of retained Fourier modes. |
| spectral_conv | SpectralConv1d, SpectralConv2d, or None | Spectral convolution layer if enabled. |
| linear | Conv1d, Conv2d, or None | Pointwise (kernel size 1) convolution layer if enabled. |
| activation | Module | Activation function. |

Examples:

>>> import torch
>>> from spectrans.layers import FourierNeuralOperator
>>> fno = FourierNeuralOperator(hidden_dim=64, modes=16)
>>> x = torch.randn(32, 128, 64)  # (batch, sequence, channels)
>>> output = fno(x)
>>> assert output.shape == x.shape

Methods:

| Name | Description |
|---|---|
| forward | Apply Fourier Neural Operator. |

Source code in spectrans/layers/operators/fno.py
def __init__(
    self,
    hidden_dim: int,
    modes: int | tuple[int, ...] = 16,
    activation: ActivationType = "gelu",
    use_spectral_conv: bool = True,
    use_linear: bool = True,
):
    super().__init__()

    self.hidden_dim = hidden_dim
    self.modes = modes
    self.use_spectral_conv = use_spectral_conv
    self.use_linear = use_linear

    if not use_spectral_conv and not use_linear:
        raise ValueError("At least one of spectral_conv or linear must be enabled")

    # Determine dimensionality
    if isinstance(modes, int):
        # 1D case
        if use_spectral_conv:
            self.spectral_conv = SpectralConv1d(hidden_dim, hidden_dim, modes)
        else:
            self.spectral_conv = None

        if use_linear:
            self.linear = nn.Conv1d(hidden_dim, hidden_dim, 1)
        else:
            self.linear = None

        self.dim = 1
    elif len(modes) == 2:
        # 2D case
        if use_spectral_conv:
            self.spectral_conv = SpectralConv2d(hidden_dim, hidden_dim, modes)
        else:
            self.spectral_conv = None

        if use_linear:
            self.linear = nn.Conv2d(hidden_dim, hidden_dim, 1)
        else:
            self.linear = None

        self.dim = 2
    else:
        raise ValueError(f"Unsupported modes shape: {modes}")

    # Activation function
    activation_fn: nn.Module
    if activation == "gelu":
        activation_fn = nn.GELU()
    elif activation == "relu":
        activation_fn = nn.ReLU()
    elif activation == "silu" or activation == "swish":
        activation_fn = nn.SiLU()
    elif activation == "tanh":
        activation_fn = nn.Tanh()
    elif activation == "sigmoid":
        activation_fn = nn.Sigmoid()
    elif activation == "identity":
        activation_fn = nn.Identity()
    else:
        raise ValueError(f"Unsupported activation: {activation}")

    self.activation = activation_fn

    self._init_weights()
Functions
forward
forward(x: Tensor) -> Tensor

Apply Fourier Neural Operator.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Input tensor in channels-last layout. 1D: (batch_size, sequence_length, hidden_dim); 2D: (batch_size, height, width, hidden_dim). | required |

Returns:

| Type | Description |
|---|---|
| Tensor | Output tensor of same shape as input. |

Source code in spectrans/layers/operators/fno.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    r"""Apply Fourier Neural Operator.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor. Shape depends on dimensionality:
        - 1D: (batch_size, sequence_length, hidden_dim)
        - 2D: (batch_size, height, width, hidden_dim)

    Returns
    -------
    torch.Tensor
        Output tensor of same shape as input.
    """
    # Cast the pointwise layer to the input dtype; the spectral conv layers cast their weights on each call
    input_dtype = x.dtype
    if self.linear is not None and self.linear.weight.dtype != input_dtype:
        self.linear = self.linear.to(input_dtype)

    if self.dim == 1:
        # For 1D, expect (batch, sequence, channels)
        # Transpose to (batch, channels, sequence) for convolution
        x = x.transpose(-1, -2)

        # Apply spectral convolution and/or linear transformation
        out = torch.zeros_like(x)
        if self.spectral_conv is not None:
            out = out + self.spectral_conv(x)
        if self.linear is not None:
            out = out + self.linear(x)

        # Apply activation
        out = self.activation(out)

        # Transpose back
        out = out.transpose(-1, -2)

    elif self.dim == 2:
        # For 2D, expect (batch, height, width, channels)
        # Permute to (batch, channels, height, width)
        x = x.permute(0, 3, 1, 2)

        # Apply spectral convolution and/or linear transformation
        out = torch.zeros_like(x)
        if self.spectral_conv is not None:
            out = out + self.spectral_conv(x)
        if self.linear is not None:
            out = out + self.linear(x)

        # Apply activation
        out = self.activation(out)

        # Permute back
        out = out.permute(0, 2, 3, 1)

    return out
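
The pointwise branch (`self.linear`) is an `nn.Conv1d` with kernel size 1, applied after moving channels to dimension 1. The following small check uses plain PyTorch, independent of spectrans. It confirms that this matches an `nn.Linear` applied directly to the channels-last input when the two layers share weights:

```python
import torch
from torch import nn

torch.manual_seed(0)
channels = 4
conv = nn.Conv1d(channels, channels, kernel_size=1)
linear = nn.Linear(channels, channels)
with torch.no_grad():
    # Conv1d weight is (out, in, 1); Linear weight is (out, in)
    linear.weight.copy_(conv.weight.squeeze(-1))
    linear.bias.copy_(conv.bias)

x = torch.randn(2, 10, channels)  # (batch, sequence, channels), the 1D input layout
y_conv = conv(x.transpose(-1, -2)).transpose(-1, -2)  # channels-first for Conv1d, then back
y_linear = linear(x)
```

The two paths agree up to floating-point rounding. Using `Conv1d` simply lets both FNO branches share the channels-first layout that the spectral convolution needs.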

SpectralConv1d

SpectralConv1d(in_channels: int, out_channels: int, modes: int)

Bases: Module

1D Spectral convolution layer.

Performs convolution in the Fourier domain by element-wise multiplication with learnable complex-valued weights on truncated modes.
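
The following equations are written out from the `forward` source. $N$ is the sequence length, $m$ is `modes`, and $R$ is the complex weight tensor. The index notation is introduced here to summarize the code:

$$
\hat{x} = \operatorname{rfft}(x), \qquad
\hat{y}_{b,o,k} =
\begin{cases}
\sum_{i} \hat{x}_{b,i,k}\, R_{i,o,k}, & 0 \le k < m,\\
0, & m \le k \le \lfloor N/2 \rfloor,
\end{cases}
\qquad
y = \operatorname{irfft}(\hat{y}, N)
$$

Each retained frequency $k$ mixes channels with its own matrix $R_{:,:,k}$. All higher frequencies are zeroed before the inverse transform.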

Parameters:

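| Name | Type | Description | Default |
|---|---|---|---|
| in_channels | int | Number of input channels. | required |
| out_channels | int | Number of output channels. | required |
| modes | int | Number of Fourier modes to keep (frequency truncation). Must not exceed sequence_length // 2 + 1 for the inputs used; otherwise the mode-wise product fails with a shape mismatch. | required |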

Attributes:

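| Name | Type | Description |
|---|---|---|
| in_channels | int | Input channel count. |
| out_channels | int | Output channel count. |
| modes | int | Number of retained Fourier modes. |
| weights | Parameter | Learnable real tensor of shape (in_channels, out_channels, modes, 2) holding real and imaginary parts. It is viewed as complex weights of shape (in_channels, out_channels, modes) in forward. |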

Examples:

>>> import torch
>>> from spectrans.layers.operators.fno import SpectralConv1d
>>> conv = SpectralConv1d(in_channels=64, out_channels=64, modes=16)
>>> x = torch.randn(32, 64, 128)  # (batch, channels, sequence)
>>> output = conv(x)
>>> assert output.shape == x.shape

Methods:

| Name | Description |
|---|---|
| forward | Apply spectral convolution. |

Source code in spectrans/layers/operators/fno.py
def __init__(self, in_channels: int, out_channels: int, modes: int):
    super().__init__()

    self.in_channels = in_channels
    self.out_channels = out_channels
    self.modes = modes

    # Complex weights for Fourier modes
    # Scale initialization for stability
    scale = 1 / (in_channels * out_channels)
    self.weights = nn.Parameter(torch.randn(in_channels, out_channels, modes, 2) * scale)
Functions
forward
forward(x: Tensor) -> Tensor

Apply spectral convolution.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Input tensor of shape (batch_size, in_channels, sequence_length). | required |

Returns:

| Type | Description |
|---|---|
| Tensor | Output tensor of shape (batch_size, out_channels, sequence_length). |

Source code in spectrans/layers/operators/fno.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    r"""Apply spectral convolution.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor of shape (batch_size, in_channels, sequence_length).

    Returns
    -------
    torch.Tensor
        Output tensor of shape (batch_size, out_channels, sequence_length).
    """
    batch_size, _, seq_len = x.shape

    # Compute FFT using safe wrapper
    x_ft = safe_rfft(x, dim=-1)

    # Truncate to retained modes
    x_ft_truncated = x_ft[..., : self.modes]

    # Prepare output in Fourier domain
    out_ft = torch.zeros(
        batch_size, self.out_channels, seq_len // 2 + 1, dtype=x_ft.dtype, device=x.device
    )

    # Apply spectral convolution via complex multiplication
    # Convert weights to complex and match input dtype
    weights_complex = torch.view_as_complex(self.weights.to(x.dtype))

    # Perform einsum for channel mixing with mode-wise multiplication
    # Shape: (batch, in_channels, modes) x (in_channels, out_channels, modes)
    # -> (batch, out_channels, modes)
    out_ft[:, :, : self.modes] = torch.einsum("bim,iom->bom", x_ft_truncated, weights_complex)

    # Inverse FFT to get back to spatial domain using safe wrapper
    out = safe_irfft(out_ft, n=seq_len, dim=-1)

    return out
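
The following functional sketch of the forward pass calls `torch.fft.rfft` and `torch.fft.irfft` directly. It assumes the library's `safe_rfft`/`safe_irfft` wrappers behave like those calls, and the function name `spectral_conv1d` is hypothetical. The explicit check exposes an implicit constraint: `modes` cannot exceed `n // 2 + 1`, the number of rfft bins. Two sanity checks follow the function:

```python
import torch


def spectral_conv1d(x, weights, modes):
    """Hypothetical functional version of SpectralConv1d.forward.

    x: real (batch, in_channels, n); weights: complex (in_channels, out_channels, modes).
    """
    n = x.shape[-1]
    if modes > n // 2 + 1:
        raise ValueError(f"modes={modes} exceeds the {n // 2 + 1} rfft bins of length {n}")
    x_ft = torch.fft.rfft(x, dim=-1)
    out_ft = torch.zeros(
        x.shape[0], weights.shape[1], n // 2 + 1, dtype=x_ft.dtype, device=x.device
    )
    # Mix channels independently at each retained frequency; higher bins stay zero
    out_ft[:, :, :modes] = torch.einsum("bim,iom->bom", x_ft[..., :modes], weights)
    return torch.fft.irfft(out_ft, n=n, dim=-1)


torch.manual_seed(0)
x = torch.randn(2, 3, 16)
eye = torch.eye(3, dtype=torch.cfloat).unsqueeze(-1)  # (3, 3, 1) identity channel mixing

# Identity weights on all 9 rfft bins of a length-16 signal reproduce the input
full = spectral_conv1d(x, eye.expand(3, 3, 9), modes=9)
# Keeping only the DC bin returns each channel's mean at every position
dc_only = spectral_conv1d(x, eye.expand(3, 3, 1), modes=1)
```

The DC-only case shows what truncation does in the extreme: with the default `irfft` normalization, the single retained bin $\sum_t x_t$ is divided by $n$, leaving the sequence mean.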

SpectralConv2d

SpectralConv2d(in_channels: int, out_channels: int, modes: tuple[int, int])

Bases: Module

2D Spectral convolution layer.

Performs 2D convolution in the Fourier domain for image-like data.

Parameters:

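| Name | Type | Description | Default |
|---|---|---|---|
| in_channels | int | Number of input channels. | required |
| out_channels | int | Number of output channels. | required |
| modes | tuple[int, int] | Number of Fourier modes to keep in each dimension (height, width). Only the first modes1 rows of the height spectrum (the non-negative frequencies) are used. modes1 must not exceed height, and modes2 must not exceed width // 2 + 1. | required |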

Attributes:

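| Name | Type | Description |
|---|---|---|
| in_channels | int | Input channel count. |
| out_channels | int | Output channel count. |
| modes1 | int | Number of retained modes in first spatial dimension. |
| modes2 | int | Number of retained modes in second spatial dimension. |
| weights | Parameter | Learnable real tensor of shape (in_channels, out_channels, modes1, modes2, 2) holding real and imaginary parts. It is viewed as complex weights of shape (in_channels, out_channels, modes1, modes2) in forward. |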

Examples:

>>> import torch
>>> from spectrans.layers.operators.fno import SpectralConv2d
>>> conv2d = SpectralConv2d(in_channels=3, out_channels=64, modes=(32, 32))
>>> x = torch.randn(8, 3, 256, 256)
>>> output = conv2d(x)
>>> assert output.shape == (8, 64, 256, 256)

Methods:

| Name | Description |
|---|---|
| forward | Apply 2D spectral convolution. |

Source code in spectrans/layers/operators/fno.py
def __init__(self, in_channels: int, out_channels: int, modes: tuple[int, int]):
    super().__init__()

    self.in_channels = in_channels
    self.out_channels = out_channels
    self.modes1 = modes[0]
    self.modes2 = modes[1]

    # Complex weights for 2D Fourier modes
    scale = 1 / (in_channels * out_channels)
    self.weights = nn.Parameter(
        torch.randn(in_channels, out_channels, self.modes1, self.modes2, 2) * scale
    )
Functions
forward
forward(x: Tensor) -> Tensor

Apply 2D spectral convolution.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Input tensor of shape (batch_size, in_channels, height, width). | required |

Returns:

| Type | Description |
|---|---|
| Tensor | Output tensor of shape (batch_size, out_channels, height, width). |

Source code in spectrans/layers/operators/fno.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    r"""Apply 2D spectral convolution.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor of shape (batch_size, in_channels, height, width).

    Returns
    -------
    torch.Tensor
        Output tensor of shape (batch_size, out_channels, height, width).
    """
    batch_size, _, h, w = x.shape

    # Compute 2D FFT using safe wrapper
    x_ft = safe_rfft2(x, dim=(-2, -1))

    # Prepare output
    out_ft = torch.zeros(
        batch_size, self.out_channels, h, w // 2 + 1, dtype=x_ft.dtype, device=x.device
    )

    # Truncate and apply convolution
    weights_complex = torch.view_as_complex(self.weights.to(x.dtype))

    # Apply convolution on truncated modes
    out_ft[:, :, : self.modes1, : self.modes2] = torch.einsum(
        "bihw,iohw->bohw", x_ft[:, :, : self.modes1, : self.modes2], weights_complex
    )

    # Inverse FFT using safe wrapper
    out = safe_irfft2(out_ft, s=(h, w), dim=(-2, -1))

    return out
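
The 2D forward pass keeps only the first `modes1` rows of the height axis. Because `rfft2` halves only the last axis, the height axis still stores negative frequencies in its last rows. The original FNO reference implementation (Li et al.) keeps those rows as well, using a second weight tensor. The following sketch of that two-corner variant is for comparison only; it is not the spectrans code, and `spectral_conv2d_two_corner`, `w_pos`, and `w_neg` are hypothetical names:

```python
import torch


def spectral_conv2d_two_corner(x, w_pos, w_neg, modes1, modes2):
    """Hypothetical two-corner 2D spectral convolution.

    x: real (batch, in_ch, h, w); w_pos, w_neg: complex (in_ch, out_ch, modes1, modes2).
    """
    b, _, h, w = x.shape
    x_ft = torch.fft.rfft2(x, dim=(-2, -1))
    out_ft = torch.zeros(b, w_pos.shape[1], h, w // 2 + 1, dtype=x_ft.dtype, device=x.device)
    # Non-negative height frequencies live in the first modes1 rows
    out_ft[:, :, :modes1, :modes2] = torch.einsum(
        "bixy,ioxy->boxy", x_ft[:, :, :modes1, :modes2], w_pos
    )
    # Negative height frequencies live in the last modes1 rows
    out_ft[:, :, -modes1:, :modes2] = torch.einsum(
        "bixy,ioxy->boxy", x_ft[:, :, -modes1:, :modes2], w_neg
    )
    return torch.fft.irfft2(out_ft, s=(h, w), dim=(-2, -1))


torch.manual_seed(0)
x = torch.randn(2, 3, 8, 8)
eye = torch.eye(3, dtype=torch.cfloat)[:, :, None, None].expand(3, 3, 4, 5)

# With identity weights, 4 + 4 height rows and all 5 rfft columns tile the whole spectrum,
# so the input is recovered
two_corner = spectral_conv2d_two_corner(x, eye, eye, modes1=4, modes2=5)

# Writing only the first corner, as SpectralConv2d.forward does, drops the negative rows
one_corner_ft = torch.zeros(2, 3, 8, 5, dtype=torch.cfloat)
one_corner_ft[:, :, :4, :5] = torch.fft.rfft2(x)[:, :, :4, :5]
one_corner = torch.fft.irfft2(one_corner_ft, s=(8, 8))
```

Whether dropping the negative rows matters depends on the data. The single-corner version has half the spectral parameters but cannot represent low-frequency content at negative vertical frequencies.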