Skip to content

Base Mixing Classes

spectrans.layers.mixing.base

Base classes for spectral mixing layers.

Provides base classes for spectral mixing layers, extending a base MixingLayer with domain-specific functionality for token mixing operations using spectral transforms. These classes define mathematical interfaces and computational requirements for spectral transformers.

Mixing layers implement core token mixing operations that replace traditional attention mechanisms in spectral transformers, providing linear or log-linear computational complexity for sequence modeling tasks.

Classes:

Name Description
MixingLayer

Base class for spectral mixing operations.

UnitaryMixingLayer

Base class for mixing layers that preserve energy (unitary operations).

FilterMixingLayer

Base class for frequency-domain filtering operations.

Examples:

Implementing a custom spectral mixing layer:

>>> from spectrans.layers.mixing.base import MixingLayer
>>> class CustomMixing(MixingLayer):
...     def forward(self, x):
...         # Custom spectral mixing implementation
...         return self.apply_spectral_operation(x)

Creating a unitary mixing layer:

>>> from spectrans.layers.mixing.base import UnitaryMixingLayer
>>> class OrthogonalMixing(UnitaryMixingLayer):
...     def forward(self, x):
...         return self.apply_unitary_transform(x)
...     def verify_unitarity(self, x):
...         # Custom verification logic
...         return True
Notes

Mathematical Properties:

All spectral mixing layers preserve shape where output equals input shape for sequence modeling, support batched processing with consistent behavior, and maintain full gradient flow for end-to-end training.

Unitary mixing layers additionally satisfy energy preservation \(||f(\mathbf{x})||^2 = ||\mathbf{x}||^2\) following Parseval's theorem and preserve inner products through orthogonality.

Filter mixing layers operate in frequency domain, applying learned filters to frequency components with localized operations in frequency space.

See Also

spectrans.core.base : Core base classes for all spectral components spectrans.transforms : Spectral transform implementations used by mixing layers

Classes

MixingLayer

MixingLayer(hidden_dim: int, dropout: float = 0.0, norm_eps: float = 1e-05)

Bases: SpectralComponent

Base class for spectral mixing operations.

Mixing layers perform token mixing operations using various spectral transforms instead of traditional attention mechanisms. This class provides spectral-specific functionality including mathematical property verification and standardized interfaces for spectral transform operations.

Parameters:

Name Type Description Default
hidden_dim int

Hidden dimension of the model.

required
dropout float

Dropout probability for regularization.

0.0
norm_eps float

Epsilon for numerical stability in normalization.

1e-5

Attributes:

Name Type Description
hidden_dim int

Hidden dimension of the model.

dropout Module

Dropout layer for regularization.

norm_eps float

Epsilon for numerical stability.

Methods:

Name Description
get_spectral_properties

Get mathematical properties of the spectral operation.

verify_shape_consistency

Verify that input and output shapes are consistent.

compute_spectral_norm

Compute spectral norm for analysis and regularization.

Source code in spectrans/layers/mixing/base.py
def __init__(
    self,
    hidden_dim: int,
    dropout: float = 0.0,
    norm_eps: float = 1e-5,
):
    super().__init__()
    self.hidden_dim = hidden_dim
    self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
    self.norm_eps = norm_eps
Functions
get_spectral_properties abstractmethod
get_spectral_properties() -> dict[str, Any]

Get mathematical properties of the spectral operation.

Returns:

Type Description
dict[str, Any]

Dictionary containing mathematical properties such as: - 'unitary': bool, whether the transform is unitary - 'real_output': bool, whether output is guaranteed real - 'frequency_domain': bool, whether operation occurs in frequency domain - 'energy_preserving': bool, whether energy is preserved

Source code in spectrans/layers/mixing/base.py
@abstractmethod
def get_spectral_properties(self) -> dict[str, Any]:
    """Get mathematical properties of the spectral operation.

    Returns
    -------
    dict[str, Any]
        Dictionary containing mathematical properties such as:
        - 'unitary': bool, whether the transform is unitary
        - 'real_output': bool, whether output is guaranteed real
        - 'frequency_domain': bool, whether operation occurs in frequency domain
        - 'energy_preserving': bool, whether energy is preserved
    """
    pass
verify_shape_consistency
verify_shape_consistency(input_tensor: Tensor, output_tensor: Tensor) -> bool

Verify that input and output shapes are consistent.

Parameters:

Name Type Description Default
input_tensor Tensor

Input tensor to the mixing layer.

required
output_tensor Tensor

Output tensor from the mixing layer.

required

Returns:

Type Description
bool

True if shapes are consistent, False otherwise.

Source code in spectrans/layers/mixing/base.py
def verify_shape_consistency(
    self, input_tensor: torch.Tensor, output_tensor: torch.Tensor
) -> bool:
    """Verify that input and output shapes are consistent.

    Parameters
    ----------
    input_tensor : torch.Tensor
        Input tensor to the mixing layer.
    output_tensor : torch.Tensor
        Output tensor from the mixing layer.

    Returns
    -------
    bool
        True if shapes are consistent, False otherwise.
    """
    if input_tensor.shape != output_tensor.shape:
        return False

    # Verify batch dimension consistency
    if input_tensor.size(0) != output_tensor.size(0):
        return False

    # Verify sequence length consistency
    if input_tensor.size(1) != output_tensor.size(1):
        return False

    # Verify hidden dimension consistency
    return input_tensor.size(2) == output_tensor.size(2)
compute_spectral_norm
compute_spectral_norm(tensor: Tensor) -> Tensor

Compute spectral norm for analysis and regularization.

Parameters:

Name Type Description Default
tensor Tensor

Input tensor to compute spectral norm for.

required

Returns:

Type Description
Tensor

Spectral norm of the input tensor.

Source code in spectrans/layers/mixing/base.py
def compute_spectral_norm(self, tensor: torch.Tensor) -> torch.Tensor:
    """Compute spectral norm for analysis and regularization.

    Parameters
    ----------
    tensor : torch.Tensor
        Input tensor to compute spectral norm for.

    Returns
    -------
    torch.Tensor
        Spectral norm of the input tensor.
    """
    # Reshape to matrix for spectral norm computation
    batch_size, seq_len, hidden_dim = tensor.shape
    matrix = tensor.view(batch_size * seq_len, hidden_dim)

    # Compute singular values
    _, s, _ = torch.svd(matrix)

    # Return maximum singular value (spectral norm)
    return torch.max(s, dim=-1)[0].mean()

UnitaryMixingLayer

UnitaryMixingLayer(hidden_dim: int, dropout: float = 0.0, norm_eps: float = 1e-05, energy_tolerance: float = 0.0001)

Bases: MixingLayer

Base class for unitary mixing operations.

Unitary mixing layers preserve energy and inner products, maintaining mathematical properties essential for stable training and theoretical guarantees in spectral transformers.

Parameters:

Name Type Description Default
hidden_dim int

Hidden dimension of the model.

required
dropout float

Dropout probability for regularization.

0.0
norm_eps float

Epsilon for numerical stability.

1e-5
energy_tolerance float

Tolerance for energy preservation verification.

1e-4

Attributes:

Name Type Description
energy_tolerance float

Tolerance for energy preservation checks.

Methods:

Name Description
get_spectral_properties

Get properties specific to unitary transforms.

verify_energy_preservation

Verify energy preservation (Parseval's theorem).

verify_orthogonality

Verify orthogonality of the transform matrix.

Source code in spectrans/layers/mixing/base.py
def __init__(
    self,
    hidden_dim: int,
    dropout: float = 0.0,
    norm_eps: float = 1e-5,
    energy_tolerance: float = 1e-4,
):
    super().__init__(hidden_dim, dropout, norm_eps)
    self.energy_tolerance = energy_tolerance
Functions
get_spectral_properties
get_spectral_properties() -> dict[str, Any]

Get properties specific to unitary transforms.

Returns:

Type Description
dict[str, Any]

Dictionary containing unitary transform properties.

Source code in spectrans/layers/mixing/base.py
def get_spectral_properties(self) -> dict[str, Any]:
    """Get properties specific to unitary transforms.

    Returns
    -------
    dict[str, Any]
        Dictionary containing unitary transform properties.
    """
    return {
        "unitary": True,
        "energy_preserving": True,
        "invertible": True,
        "orthogonal": True,
        "spectrum_preserving": True,
    }
verify_energy_preservation
verify_energy_preservation(input_tensor: Tensor, output_tensor: Tensor) -> bool

Verify energy preservation (Parseval's theorem).

Checks that \(||\mathbf{output}||^2 \approx ||\mathbf{input}||^2\) within tolerance.

Parameters:

Name Type Description Default
input_tensor Tensor

Input tensor before transformation.

required
output_tensor Tensor

Output tensor after transformation.

required

Returns:

Type Description
bool

True if energy is preserved within tolerance.

Source code in spectrans/layers/mixing/base.py
def verify_energy_preservation(
    self, input_tensor: torch.Tensor, output_tensor: torch.Tensor
) -> bool:
    r"""Verify energy preservation (Parseval's theorem).

    Checks that $||\mathbf{output}||^2 \approx ||\mathbf{input}||^2$ within tolerance.

    Parameters
    ----------
    input_tensor : torch.Tensor
        Input tensor before transformation.
    output_tensor : torch.Tensor
        Output tensor after transformation.

    Returns
    -------
    bool
        True if energy is preserved within tolerance.
    """
    input_energy = torch.norm(input_tensor, p=2, dim=-1) ** 2
    output_energy = torch.norm(output_tensor, p=2, dim=-1) ** 2

    energy_diff = torch.abs(input_energy - output_energy)
    max_energy = torch.max(input_energy, output_energy)

    # Relative error should be within tolerance
    relative_error = energy_diff / (max_energy + self.norm_eps)

    return bool(torch.all(relative_error < self.energy_tolerance))
verify_orthogonality
verify_orthogonality(transform_matrix: Tensor) -> bool

Verify orthogonality of the transform matrix.

Checks that \(\mathbf{T} \mathbf{T}^H \approx \mathbf{I}\) (identity matrix).

Parameters:

Name Type Description Default
transform_matrix Tensor

Transform matrix to verify.

required

Returns:

Type Description
bool

True if matrix is orthogonal within tolerance.

Source code in spectrans/layers/mixing/base.py
def verify_orthogonality(self, transform_matrix: torch.Tensor) -> bool:
    r"""Verify orthogonality of the transform matrix.

    Checks that $\mathbf{T} \mathbf{T}^H \approx \mathbf{I}$ (identity matrix).

    Parameters
    ----------
    transform_matrix : torch.Tensor
        Transform matrix to verify.

    Returns
    -------
    bool
        True if matrix is orthogonal within tolerance.
    """
    # Compute T @ T^H
    product = torch.matmul(transform_matrix, transform_matrix.conj().transpose(-2, -1))

    # Expected identity matrix
    identity = torch.eye(
        transform_matrix.size(-1), device=transform_matrix.device, dtype=transform_matrix.dtype
    )

    # Check deviation from identity
    deviation = torch.norm(product - identity, p="fro")

    return bool(deviation < self.energy_tolerance)

FilterMixingLayer

FilterMixingLayer(hidden_dim: int, sequence_length: int, dropout: float = 0.0, norm_eps: float = 1e-05, learnable_filters: bool = True)

Bases: MixingLayer

Base class for frequency-domain filtering operations.

Filter mixing layers apply learnable filters in the frequency domain, enabling selective emphasis or suppression of frequency components for improved sequence modeling capabilities.

Parameters:

Name Type Description Default
hidden_dim int

Hidden dimension of the model.

required
sequence_length int

Expected sequence length for filter initialization.

required
dropout float

Dropout probability for regularization.

0.0
norm_eps float

Epsilon for numerical stability.

1e-5
learnable_filters bool

Whether filters are learnable parameters.

True

Attributes:

Name Type Description
sequence_length int

Expected sequence length.

learnable_filters bool

Whether filters are learnable.

Methods:

Name Description
get_spectral_properties

Get properties specific to filtering operations.

get_filter_response

Get the frequency response of the current filters.

analyze_frequency_response

Analyze the frequency response characteristics.

Source code in spectrans/layers/mixing/base.py
def __init__(
    self,
    hidden_dim: int,
    sequence_length: int,
    dropout: float = 0.0,
    norm_eps: float = 1e-5,
    learnable_filters: bool = True,
):
    super().__init__(hidden_dim, dropout, norm_eps)
    self.sequence_length = sequence_length
    self.learnable_filters = learnable_filters
Functions
get_spectral_properties
get_spectral_properties() -> dict[str, Any]

Get properties specific to filtering operations.

Returns:

Type Description
dict[str, Any]

Dictionary containing filter-specific properties.

Source code in spectrans/layers/mixing/base.py
def get_spectral_properties(self) -> dict[str, Any]:
    """Get properties specific to filtering operations.

    Returns
    -------
    dict[str, Any]
        Dictionary containing filter-specific properties.
    """
    return {
        "frequency_domain": True,
        "learnable_filters": self.learnable_filters,
        "selective_filtering": True,
        "complex_valued": True,
        "energy_preserving": False,  # Filtering can change energy
    }
get_filter_response abstractmethod
get_filter_response() -> Tensor

Get the frequency response of the current filters.

Returns:

Type Description
Tensor

Complex-valued frequency response of shape matching the filter parameters.

Source code in spectrans/layers/mixing/base.py
@abstractmethod
def get_filter_response(self) -> torch.Tensor:
    """Get the frequency response of the current filters.

    Returns
    -------
    torch.Tensor
        Complex-valued frequency response of shape matching the filter parameters.
    """
    pass
analyze_frequency_response
analyze_frequency_response() -> dict[str, Tensor]

Analyze the frequency response characteristics.

Returns:

Type Description
dict[str, Tensor]

Dictionary containing: - 'magnitude': Magnitude response - 'phase': Phase response - 'group_delay': Group delay response - 'passband_energy': Energy in different frequency bands

Source code in spectrans/layers/mixing/base.py
def analyze_frequency_response(self) -> dict[str, torch.Tensor]:
    """Analyze the frequency response characteristics.

    Returns
    -------
    dict[str, torch.Tensor]
        Dictionary containing:
        - 'magnitude': Magnitude response
        - 'phase': Phase response
        - 'group_delay': Group delay response
        - 'passband_energy': Energy in different frequency bands
    """
    response = self.get_filter_response()

    magnitude = torch.abs(response)
    phase = torch.angle(response)

    # Compute group delay (negative derivative of phase)
    # For discrete signals, use finite differences
    phase_diff = torch.diff(phase, dim=-1)
    group_delay = -phase_diff

    # Analyze energy in different frequency bands
    total_energy = torch.sum(magnitude**2, dim=-1, keepdim=True)
    low_freq_energy = torch.sum(
        magnitude[..., : magnitude.size(-1) // 4] ** 2, dim=-1, keepdim=True
    )
    high_freq_energy = torch.sum(
        magnitude[..., 3 * magnitude.size(-1) // 4 :] ** 2, dim=-1, keepdim=True
    )

    return {
        "magnitude": magnitude,
        "phase": phase,
        "group_delay": group_delay,
        "total_energy": total_energy,
        "low_freq_energy": low_freq_energy / (total_energy + self.norm_eps),
        "high_freq_energy": high_freq_energy / (total_energy + self.norm_eps),
    }