Layer Configuration

spectrans.config.layers

Layer configuration schemas for spectral transformers.

This module provides Pydantic configuration models for all layer types in spectrans, enabling type-safe configuration of mixing layers, attention mechanisms, and neural operators. Each configuration class validates parameters and provides sensible defaults for production use.

Modules:

  • attention: Configuration models for attention layers.
  • mixing: Configuration models for mixing layers.
  • operators: Configuration models for operator layers.

Classes:

  • AFNOMixingConfig: Configuration for Adaptive Fourier Neural Operator mixing.
  • DCTAttentionConfig: Configuration for DCT-based attention.
  • FourierMixingConfig: Configuration for Fourier mixing layers.
  • GlobalFilterMixingConfig: Configuration for global filter networks.
  • HadamardAttentionConfig: Configuration for Hadamard attention.
  • LSTAttentionConfig: Configuration for Linear Spectral Transform attention.
  • MixedTransformAttentionConfig: Configuration for mixed transform attention.
  • SpectralAttentionConfig: Configuration for spectral attention with RFF.
  • SpectralKernelAttentionConfig: Configuration for kernel-based spectral attention.
  • WaveletMixing2DConfig: Configuration for 2D wavelet mixing.
  • WaveletMixingConfig: Configuration for 1D wavelet mixing.

Examples:

Configuring a Fourier mixing layer:

>>> from spectrans.config.layers import FourierMixingConfig
>>>
>>> config = FourierMixingConfig(
...     hidden_dim=768,
...     dropout=0.1,
...     keep_complex=False
... )
>>> print(config.model_dump())

Configuring spectral attention:

>>> from spectrans.config.layers import SpectralAttentionConfig
>>>
>>> config = SpectralAttentionConfig(
...     hidden_dim=512,
...     num_heads=8,
...     num_features=256,
...     kernel_type="gaussian"
... )

Validation example:

>>> from spectrans.config.layers import GlobalFilterMixingConfig
>>>
>>> # This will raise a validation error
>>> try:
...     config = GlobalFilterMixingConfig(
...         hidden_dim=-1,  # Invalid dimension
...         sequence_length=512
...     )
... except ValueError as e:
...     print(f"Validation error: {e}")

Notes

Configuration Validation:

All configuration classes perform:

  • Range validation (e.g., dimensions > 0)
  • Type coercion where appropriate
  • Default value assignment
  • Cross-field validation where needed
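
The four behaviours listed above can be illustrated without spectrans installed. The following is a standard-library sketch that mimics them with a dataclass. `SketchAttentionConfig`, its fields, and the divisibility rule are hypothetical stand-ins; the real classes are Pydantic models and their exact checks may differ.

```python
from dataclasses import dataclass


@dataclass
class SketchAttentionConfig:
    """Hypothetical stand-in for a layer config; not a spectrans class."""

    hidden_dim: int
    num_heads: int = 8    # default value assignment
    dropout: float = 0.0  # default value assignment

    def __post_init__(self) -> None:
        # Type coercion: accept values such as "512" and normalize them.
        self.hidden_dim = int(self.hidden_dim)
        self.num_heads = int(self.num_heads)
        self.dropout = float(self.dropout)

        # Range validation.
        if self.hidden_dim <= 0:
            raise ValueError("hidden_dim must be positive")
        if self.num_heads <= 0:
            raise ValueError("num_heads must be positive")
        if not 0.0 <= self.dropout <= 1.0:
            raise ValueError("dropout must be in [0.0, 1.0]")

        # Cross-field validation (assumed rule): heads must divide hidden_dim.
        if self.hidden_dim % self.num_heads != 0:
            raise ValueError("hidden_dim must be divisible by num_heads")


config = SketchAttentionConfig(hidden_dim="512")
print(config)

try:
    SketchAttentionConfig(hidden_dim=-1)
except ValueError as exc:
    print(f"Validation error: {exc}")
```

In this sketch all checks run once at construction time, so an invalid configuration never reaches layer code.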

Common Parameters:

  • hidden_dim: Model hidden dimension
  • dropout: Dropout probability (0.0-1.0)
  • bias: Whether to include bias terms
  • activation: Activation function type

Layer-Specific Parameters:

  • Mixing layers: sequence_length, normalization
  • Attention layers: num_heads, num_features, kernel_type
  • Operator layers: modes, grid_size, lifting_dim

See Also

  • spectrans.layers: Actual layer implementations.
  • spectrans.config: Parent configuration module.
  • spectrans.config.models: Model configuration schemas.

Classes

DCTAttentionConfig

Bases: AttentionLayerConfig

Configuration for DCT-based attention layer.

Attributes:

  • dct_type (int): Type of DCT transform (typically 2), defaults to 2.
  • learnable_scale (bool): Whether to use learnable diagonal scaling, defaults to True.
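
For orientation, a DCT type of 2 conventionally refers to the DCT-II. The following pure-Python sketch implements the unnormalized textbook definition; the helper name `dct2` is hypothetical, and the normalization spectrans applies is not stated here, so treat the scaling as an assumption.

```python
import math


def dct2(x: list[float]) -> list[float]:
    """Unnormalized DCT-II: X[k] = sum_n x[n] * cos(pi / N * (n + 0.5) * k)."""
    n_points = len(x)
    return [
        sum(
            x[n] * math.cos(math.pi / n_points * (n + 0.5) * k)
            for n in range(n_points)
        )
        for k in range(n_points)
    ]


# A constant signal has all of its energy in the k = 0 coefficient.
print(dct2([1.0, 1.0, 1.0, 1.0]))
```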

HadamardAttentionConfig

Bases: AttentionLayerConfig

Configuration for Hadamard-based attention layer.

Attributes:

  • scale_by_sqrt (bool): Whether to scale by sqrt(n), defaults to True.
  • learnable_scale (bool): Whether to use learnable diagonal scaling, defaults to True.
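
To make the effect of scale_by_sqrt concrete, here is a minimal fast Walsh-Hadamard transform in pure Python. It assumes that "scale by sqrt(n)" means dividing the output by sqrt(n), which makes the transform orthonormal and its own inverse; the function name `hadamard_transform` is hypothetical and this is not the library's implementation.

```python
import math


def hadamard_transform(x: list[float], scale_by_sqrt: bool = True) -> list[float]:
    """Fast Walsh-Hadamard transform of a power-of-two-length sequence."""
    n = len(x)
    if n == 0 or n & (n - 1):
        raise ValueError("length must be a power of two")
    out = list(x)
    h = 1
    while h < n:
        # Butterfly step: combine elements that are h positions apart.
        for start in range(0, n, 2 * h):
            for i in range(start, start + h):
                a, b = out[i], out[i + h]
                out[i], out[i + h] = a + b, a - b
        h *= 2
    if scale_by_sqrt:
        # Assumed meaning of scale_by_sqrt: divide by sqrt(n).
        out = [v / math.sqrt(n) for v in out]
    return out


x = [1.0, 2.0, 3.0, 4.0]
once = hadamard_transform(x)
twice = hadamard_transform(once)  # with scaling, applying it twice recovers x
print(once, twice)
```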

LSTAttentionConfig

Bases: AttentionLayerConfig

Configuration for Linear Spectral Transform Attention.

Attributes:

  • transform_type (TransformLSTType): Type of spectral transform ('dct', 'dst', 'hadamard', 'mixed'), defaults to 'dct'.
  • learnable_scale (bool): Whether to use learnable diagonal scaling, defaults to True.
  • normalize (bool): Whether to normalize transform output, defaults to True.
  • use_bias (bool): Whether to use bias in projections, defaults to True.

MixedTransformAttentionConfig

Bases: AttentionLayerConfig

Configuration for mixed transform attention layer.

Attributes:

  • use_fft (bool): Whether to use FFT transforms, defaults to True.
  • use_dct (bool): Whether to use DCT transforms, defaults to True.
  • use_hadamard (bool): Whether to use Hadamard transforms, defaults to True.

SpectralAttentionConfig

Bases: AttentionLayerConfig

Configuration for Spectral Attention with Random Fourier Features.

Attributes:

  • num_features (int | None): Number of random Fourier features, defaults to None (uses head_dim).
  • kernel_type (KernelType): Type of kernel ('gaussian' or 'softmax'), defaults to 'softmax'.
  • use_orthogonal (bool): Whether to use orthogonal random features, defaults to True.
  • feature_redraw (bool): Whether to redraw features during training, defaults to False.
  • use_bias (bool): Whether to use bias in projections, defaults to True.
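
As background for num_features and the 'gaussian' kernel type, the sketch below shows the standard random Fourier feature construction, in which an inner product of feature vectors approximates a Gaussian kernel. It uses i.i.d. Gaussian frequencies rather than orthogonal ones, and `make_rff` is a hypothetical helper, so it illustrates the general technique and not the library's feature map.

```python
import math
import random


def make_rff(dim: int, num_features: int, seed: int = 0):
    """Return a feature map phi with phi(x) . phi(y) ~ exp(-||x - y||^2 / 2)."""
    rng = random.Random(seed)
    # Random frequencies drawn from N(0, I) and phases from U[0, 2*pi).
    weights = [[rng.gauss(0.0, 1.0) for _ in range(dim)] for _ in range(num_features)]
    phases = [rng.uniform(0.0, 2.0 * math.pi) for _ in range(num_features)]
    scale = math.sqrt(2.0 / num_features)

    def features(x: list[float]) -> list[float]:
        return [
            scale * math.cos(sum(w_i * x_i for w_i, x_i in zip(w, x)) + b)
            for w, b in zip(weights, phases)
        ]

    return features


phi = make_rff(dim=3, num_features=4000)
x, y = [0.1, 0.2, 0.3], [0.3, 0.1, 0.0]
approx = sum(a * b for a, b in zip(phi(x), phi(y)))
exact = math.exp(-sum((a - b) ** 2 for a, b in zip(x, y)) / 2.0)
print(approx, exact)
```

More features reduce the approximation error at the cost of larger feature vectors, which is the trade-off that num_features controls.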

SpectralKernelAttentionConfig

Bases: AttentionLayerConfig

Configuration for spectral kernel attention.

Attributes:

  • kernel_type (SpectralKernelType): Type of spectral kernel ('gaussian', 'polynomial', 'spectral'), defaults to 'gaussian'.
  • rank (int | None): Rank for low-rank approximation, defaults to None (uses min(64, head_dim)).
  • num_features (int | None): Number of features for approximation, defaults to None.
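
The fallback described for rank can be written out as a small helper. This is a sketch only: `resolve_rank` is a hypothetical name, and it assumes head_dim is hidden_dim // num_heads, which this page does not state.

```python
from __future__ import annotations


def resolve_rank(rank: int | None, hidden_dim: int, num_heads: int) -> int:
    """Return the explicit rank, or min(64, head_dim) when rank is None."""
    head_dim = hidden_dim // num_heads  # assumed definition of head_dim
    return min(64, head_dim) if rank is None else rank


print(resolve_rank(None, hidden_dim=256, num_heads=8))
print(resolve_rank(16, hidden_dim=512, num_heads=8))
```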

AFNOMixingConfig

Bases: BaseLayerConfig

Configuration for Adaptive Fourier Neural Operator mixing layers.

Attributes:

  • max_sequence_length (int): Maximum sequence length for mode truncation, must be positive.
  • modes_seq (int | None): Number of Fourier modes in sequence dimension, defaults to max_sequence_length // 2.
  • modes_hidden (int | None): Number of Fourier modes in hidden dimension, defaults to hidden_dim // 2.
  • mlp_ratio (float): MLP expansion ratio in frequency domain, defaults to 2.0.
  • activation (ActivationType): Activation function for MLP, defaults to "gelu".
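
The mode defaults above amount to a small piece of arithmetic, sketched here. `resolve_afno_modes` is a hypothetical helper that mirrors the documented defaults; where the library actually resolves them is not shown on this page.

```python
from __future__ import annotations


def resolve_afno_modes(
    hidden_dim: int,
    max_sequence_length: int,
    modes_seq: int | None = None,
    modes_hidden: int | None = None,
) -> tuple[int, int]:
    """Fill in modes_seq and modes_hidden using the documented defaults."""
    if max_sequence_length <= 0:
        raise ValueError("max_sequence_length must be positive")
    if modes_seq is None:
        modes_seq = max_sequence_length // 2
    if modes_hidden is None:
        modes_hidden = hidden_dim // 2
    return modes_seq, modes_hidden


print(resolve_afno_modes(hidden_dim=768, max_sequence_length=512))
print(resolve_afno_modes(hidden_dim=768, max_sequence_length=512, modes_seq=64))
```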

FourierMixingConfig

Bases: UnitaryLayerConfig

Configuration for standard Fourier mixing layers.

Attributes:

  • keep_complex (bool): If True, keeps complex values from FFT. If False (default), takes only the real part as in original FNet.
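
To show what keep_complex changes, the following sketch applies a discrete Fourier transform along the hidden and sequence dimensions and then optionally keeps only the real part, as FNet does. It uses a naive O(n^2) DFT from the standard library for clarity; `dft` and `fourier_mix` are hypothetical names, and a real layer would use an FFT on tensors.

```python
import cmath


def dft(xs):
    """Naive discrete Fourier transform of a 1D sequence."""
    n = len(xs)
    return [
        sum(x * cmath.exp(-2j * cmath.pi * k * t / n) for t, x in enumerate(xs))
        for k in range(n)
    ]


def fourier_mix(x, keep_complex=False):
    """Mix a (seq_len x hidden_dim) nested list with a 2D DFT."""
    rows = [dft(row) for row in x]                 # transform along hidden dim
    cols = [dft(list(col)) for col in zip(*rows)]  # transform along sequence dim
    mixed = [list(r) for r in zip(*cols)]          # back to seq_len x hidden_dim
    if keep_complex:
        return mixed
    # Default: discard the imaginary part.
    return [[v.real for v in row] for row in mixed]


print(fourier_mix([[1.0, 2.0], [3.0, 4.0]]))
print(fourier_mix([[1.0, 2.0], [3.0, 4.0]], keep_complex=True))
```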

GlobalFilterMixingConfig

Bases: FilterLayerConfig

Configuration for global filter mixing layers.

Attributes:

  • activation (ActivationType): Activation function for filters, defaults to "sigmoid".

WaveletMixing2DConfig

Bases: BaseModel

Configuration model for WaveletMixing2D layer.

Attributes:

  • channels (int): Number of input channels, must be positive.
  • wavelet (WaveletType): Wavelet family name, defaults to "db4".
  • levels (int): Number of decomposition levels, must be between 1 and 6, defaults to 2.
  • mixing_mode (Literal['subband', 'channel']): 2D mixing operation mode, defaults to "subband".

Methods:

  • validate_wavelet: Validate that wavelet name is supported.

Functions
validate_wavelet classmethod
validate_wavelet(v: WaveletType) -> WaveletType

Validate that wavelet name is supported. The validator body shown in the source only rejects an empty name.

Source code in spectrans/config/layers/mixing.py
@field_validator("wavelet")
@classmethod
def validate_wavelet(cls, v: WaveletType) -> WaveletType:
    """Validate that wavelet name is supported."""
    if not v:
        raise ValueError("Wavelet name cannot be empty")
    return v

WaveletMixingConfig

Bases: BaseLayerConfig

Configuration model for WaveletMixing layer.

Attributes:

  • wavelet (WaveletType): Wavelet family name, defaults to "db4".
  • levels (int): Number of decomposition levels, must be between 1 and 6, defaults to 3.
  • mixing_mode (Literal['pointwise', 'subband']): Mixing operation mode, defaults to "pointwise".
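
To illustrate what the levels parameter controls, here is a multi-level decomposition using the Haar wavelet, which is the simplest case. The documented default is "db4", so Haar is a stand-in chosen for brevity; `haar_decompose` is a hypothetical helper and the length check is an assumption of this sketch.

```python
import math


def haar_decompose(signal: list[float], levels: int):
    """Return (approximation, details) after `levels` Haar decomposition steps."""
    if not 1 <= levels <= 6:
        raise ValueError("levels must be between 1 and 6")
    if len(signal) % (2 ** levels) != 0:
        raise ValueError("signal length must be divisible by 2 ** levels")
    approx = list(signal)
    details = []
    for _ in range(levels):
        pairs = list(zip(approx[0::2], approx[1::2]))
        # Each level halves the length: differences become detail coefficients,
        # and sums become the next, coarser approximation.
        details.append([(a - b) / math.sqrt(2) for a, b in pairs])
        approx = [(a + b) / math.sqrt(2) for a, b in pairs]
    return approx, details


approx, details = haar_decompose([1.0] * 8, levels=3)
print(approx, [len(d) for d in details])
```

Each additional level adds one detail subband and halves the approximation, which is why the usable number of levels depends on the sequence length.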

Methods:

  • validate_wavelet: Validate that wavelet name is supported.

Functions
validate_wavelet classmethod
validate_wavelet(v: WaveletType) -> WaveletType

Validate that wavelet name is supported. The validator body shown in the source only rejects an empty name.

Source code in spectrans/config/layers/mixing.py
@field_validator("wavelet")
@classmethod
def validate_wavelet(cls, v: WaveletType) -> WaveletType:
    """Validate that wavelet name is supported."""
    if not v:
        raise ValueError("Wavelet name cannot be empty")
    return v