
Core Components

spectrans.core

Core components and interfaces for the spectrans library.

This module provides the fundamental building blocks for spectral transformer implementations, including abstract base classes, type definitions, and the component registry system. All spectral transformer components inherit from these base classes to ensure consistent APIs and enable modular composition through the registry.

The core module establishes the mathematical foundations and software architecture that allow flexible experimentation with different combinations of spectral transforms while maintaining type safety and performance.

Modules:

Name Description
base

Base classes and interfaces for all spectral components.

registry

Component registration and discovery system.

types

Type definitions and aliases for the library.

Classes:

Name Description
AttentionLayer

Base class for spectral attention mechanisms.

BaseModel

Base class for spectral transformer models.

ComponentRegistry

Registry system for dynamic component discovery and instantiation.

SpectralComponent

Abstract base class for all spectral neural network components.

TransformerBlock

Base class for transformer blocks with residual connections.

Functions:

Name Description
create_component

Factory function to create registered component instances.

get_component

Retrieve component class from registry.

list_components

List all registered components in a category.

register_component

Decorator to register components in the global registry.

Attributes:

Name Type Description
registry ComponentRegistry

Global component registry instance.

Examples:

Using the component registry system:

>>> from spectrans.core import register_component, create_component
>>> from spectrans.layers.mixing.base import MixingLayer
>>> @register_component('mixing', 'custom')
... class CustomMixing(MixingLayer):
...     def forward(self, x):
...         return x  # Custom implementation
>>> mixing = create_component('mixing', 'custom', hidden_dim=768)

Working with base classes for type safety:

>>> import torch
>>> from spectrans.core import SpectralComponent
>>> def process_component(component: SpectralComponent, input_tensor: torch.Tensor):
...     output = component(input_tensor)
...     return output

Notes

The core architecture follows these design principles:

  1. Abstract Interfaces: All components implement consistent forward() methods
  2. Type Safety: Type hints with Python 3.13 syntax
  3. Modularity: Registry system enables runtime component composition
  4. Extensibility: Easy to add new transforms and mixing strategies

The registry system supports seven categories of components:

  - transform: Spectral transforms (FFT, DCT, DWT, Hadamard)
  - mixing: Token mixing layers (FourierMixing, GlobalFilter, etc.)
  - attention: Spectral attention mechanisms
  - block: Transformer blocks
  - model: Model implementations
  - kernel: Kernel functions for attention approximation
  - operator: Operator components

See Also

spectrans.core.base : Base class definitions and interfaces.
spectrans.core.types : Type aliases and definitions.
spectrans.core.registry : Component registration system.

Classes

AttentionLayer

AttentionLayer(hidden_dim: int, num_heads: int = 1, dropout: float = 0.0)

Bases: SpectralComponent

Base class for attention layers.

Attention layers implement various forms of spectral attention mechanisms as alternatives to standard multi-head attention.

Parameters:

Name Type Description Default
hidden_dim int

Hidden dimension of the model.

required
num_heads int

Number of attention heads.

1
dropout float

Dropout probability.

0.0

Attributes:

Name Type Description
hidden_dim int

Hidden dimension of the model.

num_heads int

Number of attention heads.

dropout Module

nn.Dropout layer, or nn.Identity if dropout is 0.

Source code in spectrans/core/base.py
def __init__(
    self,
    hidden_dim: int,
    num_heads: int = 1,
    dropout: float = 0.0,
):
    super().__init__()
    self.hidden_dim = hidden_dim
    self.num_heads = num_heads
    self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()

BaseModel

BaseModel(num_layers: int, hidden_dim: int, max_seq_length: int = 512, vocab_size: int | None = None, num_classes: int | None = None, dropout: float = 0.0)

Bases: Module

Base class for spectral transformer models.

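This class provides the functionality shared by all spectral transformer model variants: an optional token embedding, learned positional embeddings, dropout, and an optional mean-pooled classification head. Subclasses are responsible for populating self.blocks.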

Parameters:

Name Type Description Default
num_layers int

Number of transformer layers.

required
hidden_dim int

Hidden dimension of the model.

required
max_seq_length int

Maximum sequence length.

512
vocab_size int | None

Vocabulary size for embedding layer. If None, no embedding is used.

None
num_classes int | None

Number of output classes. If None, no classification head is used.

None
dropout float

Dropout probability.

0.0

Attributes:

Name Type Description
num_layers int

Number of transformer layers.

hidden_dim int

Hidden dimension of the model.

max_seq_length int

Maximum sequence length.

vocab_size int | None

Vocabulary size.

num_classes int | None

Number of output classes.

embedding Embedding | None

Optional embedding layer.

pos_embedding Parameter

Learned positional embeddings of shape (1, max_seq_length, hidden_dim), initialized to zeros.

dropout Module

Dropout layer, or nn.Identity if dropout is 0.

classifier Linear | None

Optional classification head.

blocks ModuleList

List of transformer blocks (populated by subclasses).

Methods:

Name Description
forward

Forward pass through the model.

Source code in spectrans/core/base.py
def __init__(
    self,
    num_layers: int,
    hidden_dim: int,
    max_seq_length: int = 512,
    vocab_size: int | None = None,
    num_classes: int | None = None,
    dropout: float = 0.0,
):
    super().__init__()
    self.num_layers = num_layers
    self.hidden_dim = hidden_dim
    self.max_seq_length = max_seq_length
    self.vocab_size = vocab_size
    self.num_classes = num_classes

    # Optional embedding layer
    self.embedding = nn.Embedding(vocab_size, hidden_dim) if vocab_size else None

    # Positional embedding
    self.pos_embedding = nn.Parameter(torch.zeros(1, max_seq_length, hidden_dim))

    # Dropout
    self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()

    # Optional classification head
    self.classifier = nn.Linear(hidden_dim, num_classes) if num_classes else None

    # Transformer blocks will be defined in subclasses
    self.blocks = nn.ModuleList()
Functions
forward
forward(x: Tensor, mask: Tensor | None = None) -> Tensor

Forward pass through the model.

Parameters:

Name Type Description Default
x Tensor

Input tensor. Shape depends on whether embedding is used:
  - With embedding: (batch_size, sequence_length) containing token indices
  - Without embedding: (batch_size, sequence_length, hidden_dim)

required
mask Tensor | None

Optional attention mask of shape (batch_size, sequence_length). Accepted for interface compatibility but currently ignored by the base implementation.

None

Returns:

Type Description
Tensor

Output tensor. Shape depends on whether a classifier is used:
  - With classifier: (batch_size, num_classes), computed from hidden states mean-pooled over the sequence dimension
  - Without classifier: (batch_size, sequence_length, hidden_dim)

Source code in spectrans/core/base.py
def forward(
    self,
    x: torch.Tensor,
    mask: torch.Tensor | None = None,  # noqa: ARG002
) -> torch.Tensor:
    """Forward pass through the model.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor. Shape depends on whether embedding is used:
        - With embedding: (batch_size, sequence_length) containing token indices
        - Without embedding: (batch_size, sequence_length, hidden_dim)
    mask : torch.Tensor | None, default=None
        Optional attention mask of shape (batch_size, sequence_length).

    Returns
    -------
    torch.Tensor
        Output tensor. Shape depends on whether classifier is used:
        - With classifier: (batch_size, num_classes)
        - Without classifier: (batch_size, sequence_length, hidden_dim)
    """
    # Apply embedding if needed
    if self.embedding is not None:
        x = self.embedding(x)

    # Add positional embeddings
    seq_len = x.size(1)
    x = x + self.pos_embedding[:, :seq_len, :]

    # Apply dropout
    x = self.dropout(x)

    # Pass through transformer blocks
    for block in self.blocks:
        x = block(x)

    # Apply classification head if needed
    if self.classifier is not None:
        # Pool over the sequence dimension (mean pooling)
        x = x.mean(dim=1)
        x = self.classifier(x)

    return x
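The shape contract of forward can be traced with plain PyTorch modules. The sketch below re-creates the embedding, positional-embedding, mean-pooling, and classifier steps outside of BaseModel, with no transformer blocks in between; the sizes (vocab_size=100, hidden_dim=64, num_classes=3) are arbitrary illustration values, not library defaults.

```python
import torch
import torch.nn as nn

# Illustrative sizes (assumptions, not library defaults)
batch_size, seq_len, max_seq_length = 2, 10, 512
vocab_size, hidden_dim, num_classes = 100, 64, 3

embedding = nn.Embedding(vocab_size, hidden_dim)
pos_embedding = nn.Parameter(torch.zeros(1, max_seq_length, hidden_dim))
classifier = nn.Linear(hidden_dim, num_classes)

tokens = torch.randint(0, vocab_size, (batch_size, seq_len))  # (2, 10) token indices
x = embedding(tokens)                                         # (2, 10, 64)
x = x + pos_embedding[:, :seq_len, :]                         # broadcast over the batch
# Transformer blocks would run here; each one preserves the (2, 10, 64) shape.
pooled = x.mean(dim=1)                                        # mean pooling -> (2, 64)
logits = classifier(pooled)                                   # (2, 3)
print(tuple(x.shape), tuple(logits.shape))
```

Because the positional table is sliced to seq_len, an input longer than max_seq_length would make the addition fail with a shape mismatch rather than being truncated.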

SpectralComponent

Bases: Module, ABC

Base class for all spectral components.

This abstract base class defines the interface that all spectral transformer components must implement.

Methods:

Name Description
forward

Forward pass.

Functions
forward abstractmethod
forward(x: Tensor, *args, **kwargs) -> Tensor | tuple[Tensor, ...]

Forward pass.

Parameters:

Name Type Description Default
x Tensor

Input tensor of shape (batch_size, sequence_length, hidden_dim).

required
*args Any

Additional positional arguments for subclass-specific parameters.

()
**kwargs Any

Additional keyword arguments for subclass-specific parameters.

{}

Returns:

Type Description
Tensor | tuple[Tensor, ...]

Output tensor(s). Single tensor for most cases, tuple for attention layers that optionally return attention weights.

Source code in spectrans/core/base.py
@abstractmethod
def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor | tuple[torch.Tensor, ...]:  # type: ignore[no-untyped-def]
    """Forward pass.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor of shape (batch_size, sequence_length, hidden_dim).
    *args : Any
        Additional positional arguments for subclass-specific parameters.
    **kwargs : Any
        Additional keyword arguments for subclass-specific parameters.

    Returns
    -------
    torch.Tensor | tuple[torch.Tensor, ...]
        Output tensor(s). Single tensor for most cases, tuple for attention
        layers that optionally return attention weights.
    """
    pass
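Because SpectralComponent combines nn.Module with ABC and marks forward as an abstract method, a subclass that forgets to implement forward cannot be instantiated at all. The sketch below reproduces that declaration with a hypothetical stand-in class (SpectralComponentLike) so it runs without the library; Incomplete and Scale are illustrative names.

```python
from abc import ABC, abstractmethod

import torch
import torch.nn as nn


class SpectralComponentLike(nn.Module, ABC):
    """Hypothetical stand-in mirroring the SpectralComponent declaration."""

    @abstractmethod
    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        ...


class Incomplete(SpectralComponentLike):
    pass  # forgets to implement forward


class Scale(SpectralComponentLike):
    def __init__(self, factor: float) -> None:
        super().__init__()
        self.factor = factor

    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        return x * self.factor  # shape-preserving, as the interface expects


try:
    Incomplete()  # ABCMeta refuses to build a class with an abstract forward
    instantiated = True
except TypeError:
    instantiated = False

y = Scale(2.0)(torch.ones(1, 4, 8))  # (batch_size, sequence_length, hidden_dim)
print(instantiated, tuple(y.shape))
```

The error surfaces at construction time rather than on the first forward call, which is the main practical benefit over relying on nn.Module's default unimplemented forward.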

TransformerBlock

TransformerBlock(mixing_layer: MixingLayer | AttentionLayer, ffn: Module | None = None, norm_layer: type[Module] = LayerNorm, dropout: float = 0.0)

Bases: SpectralComponent

Base class for transformer blocks.

Transformer blocks combine mixing/attention layers with feedforward networks and normalization to form complete transformer layers.
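Reading the forward source of this class, each sublayer is wrapped in a pre-norm residual connection, and the same dropout module is applied to both branches. Written out, with Mix denoting the mixing or attention layer:

```latex
\begin{aligned}
h &= x + \mathrm{Dropout}\big(\mathrm{Mix}(\mathrm{Norm}_1(x))\big) \\
y &= h + \mathrm{Dropout}\big(\mathrm{FFN}(\mathrm{Norm}_2(h))\big)
\end{aligned}
```

When ffn is None, the second step is skipped and y = h.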

Parameters:

Name Type Description Default
mixing_layer MixingLayer | AttentionLayer

The mixing or attention layer for token interactions.

required
ffn Module | None

Feedforward network module. If None, no FFN is used.

None
norm_layer type[Module]

Normalization layer class to use.

nn.LayerNorm
dropout float

Dropout probability applied to each sublayer output before it is added back to the residual stream.

0.0

Attributes:

Name Type Description
mixing_layer MixingLayer | AttentionLayer

The mixing or attention layer.

ffn Module | None

Feedforward network module.

hidden_dim int

Hidden dimension extracted from mixing layer.

norm1 Module

First normalization layer.

norm2 Module | None

Second normalization layer, or None when no FFN is used.

dropout Module

Dropout layer for residual connections.

Methods:

Name Description
forward

Forward pass through transformer block.

Source code in spectrans/core/base.py
def __init__(
    self,
    mixing_layer: "MixingLayer | AttentionLayer",
    ffn: nn.Module | None = None,
    norm_layer: type[nn.Module] = nn.LayerNorm,
    dropout: float = 0.0,
):
    super().__init__()
    self.mixing_layer = mixing_layer
    self.ffn = ffn

    # Get hidden dimension from mixing layer
    self.hidden_dim = mixing_layer.hidden_dim

    # Setup normalization layers
    self.norm1 = norm_layer(self.hidden_dim)
    self.norm2 = norm_layer(self.hidden_dim) if ffn is not None else None

    # Dropout for residual connections
    self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
Functions
forward
forward(x: Tensor) -> Tensor

Forward pass through transformer block.

Parameters:

Name Type Description Default
x Tensor

Input tensor of shape (batch_size, sequence_length, hidden_dim).

required

Returns:

Type Description
Tensor

Output tensor of same shape as input.

Source code in spectrans/core/base.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Forward pass through transformer block.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor of shape (batch_size, sequence_length, hidden_dim).

    Returns
    -------
    torch.Tensor
        Output tensor of same shape as input.
    """
    # Mixing/attention with residual connection
    residual = x
    x = self.norm1(x)
    x = self.mixing_layer(x)
    x = self.dropout(x)
    x = residual + x

    # FFN with residual connection (if FFN exists)
    if self.ffn is not None and self.norm2 is not None:
        residual = x
        x = self.norm2(x)
        x = self.ffn(x)
        x = self.dropout(x)
        x = residual + x

    return x
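TransformerBlock relies on only two things from its mixing_layer: a hidden_dim attribute (read in __init__) and a shape-preserving call. The sketch below defines a hypothetical, parameter-free FourierMixingSketch in the style of FNet (real part of a 2-D FFT over the sequence and hidden dimensions) and runs one pre-norm residual step by hand, mirroring the first half of forward. It is an illustration, not the library's own FourierMixing.

```python
import torch
import torch.nn as nn


class FourierMixingSketch(nn.Module):
    """Hypothetical FNet-style token mixing; not spectrans' FourierMixing."""

    def __init__(self, hidden_dim: int) -> None:
        super().__init__()
        self.hidden_dim = hidden_dim  # TransformerBlock reads this attribute

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # FFT over (sequence, hidden) dims, keep the real part: shape is preserved
        return torch.fft.fft2(x, dim=(-2, -1)).real


mixing = FourierMixingSketch(hidden_dim=32)
norm1 = nn.LayerNorm(mixing.hidden_dim)
dropout = nn.Identity()  # what TransformerBlock uses when dropout == 0.0

x = torch.randn(4, 16, 32)  # (batch_size, sequence_length, hidden_dim)
out = x + dropout(mixing(norm1(x)))  # pre-norm residual, as in forward()
print(tuple(out.shape))
```

Under the declared type hints, a production layer would subclass MixingLayer or AttentionLayer rather than nn.Module directly.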

ComponentRegistry

ComponentRegistry()

Registry for dynamically registering and retrieving components.

This registry allows for flexible component registration and retrieval, enabling users to easily extend the library with custom implementations.

Attributes:

Name Type Description
_components RegistryDict

Dictionary mapping component categories to their registered components.

_metadata dict[str, dict[str, dict[str, Any]]]

Dictionary storing metadata about registered components.

Methods:

Name Description
register

Register a component.

get

Get a registered component.

list

List registered components.

get_metadata

Get metadata for a registered component.

create

Create an instance of a registered component.

create_from_config

Create a component instance from a configuration dictionary.

__contains__

Check if a component is registered.

clear

Clear registered components.

Source code in spectrans/core/registry.py
def __init__(self) -> None:
    self._components: RegistryDict = {
        "transform": {},
        "mixing": {},
        "attention": {},
        "block": {},
        "model": {},
        "kernel": {},
        "operator": {},
    }

    # Store metadata about components
    self._metadata: dict[str, dict[str, dict[str, Any]]] = {
        category: {} for category in self._components
    }
Functions
register
register(category: ComponentType, name: str, component: ComponentClass, metadata: dict[str, Any] | None = None) -> None

Register a component.

Parameters:

Name Type Description Default
category ComponentType

Category of the component (e.g., 'transform', 'mixing').

required
name str

Name to register the component under.

required
component ComponentClass

The component class to register.

required
metadata dict[str, Any] | None

Optional metadata about the component.

None

Raises:

Type Description
ValueError

If the category is unknown or component name already exists.

Source code in spectrans/core/registry.py
def register(
    self,
    category: ComponentType,
    name: str,
    component: ComponentClass,
    metadata: dict[str, Any] | None = None,
) -> None:
    """Register a component.

    Parameters
    ----------
    category : ComponentType
        Category of the component (e.g., 'transform', 'mixing').
    name : str
        Name to register the component under.
    component : ComponentClass
        The component class to register.
    metadata : dict[str, Any] | None, default=None
        Optional metadata about the component.

    Raises
    ------
    ValueError
        If the category is unknown or component name already exists.
    """
    if category not in self._components:
        raise ValueError(
            f"Unknown category: {category}. "
            f"Available categories: {list(self._components.keys())}"
        )

    if name in self._components[category]:
        raise ValueError(f"Component '{name}' already registered in category '{category}'")

    self._components[category][name] = component

    if metadata is not None:
        self._metadata[category][name] = metadata
get
get(category: ComponentType, name: str) -> ComponentClass

Get a registered component.

Parameters:

Name Type Description Default
category ComponentType

Category of the component.

required
name str

Name of the component.

required

Returns:

Type Description
ComponentClass

The registered component class.

Raises:

Type Description
ValueError

If the category or component name is not found.

Source code in spectrans/core/registry.py
def get(self, category: ComponentType, name: str) -> ComponentClass:
    """Get a registered component.

    Parameters
    ----------
    category : ComponentType
        Category of the component.
    name : str
        Name of the component.

    Returns
    -------
    ComponentClass
        The registered component class.

    Raises
    ------
    ValueError
        If the category or component name is not found.
    """
    if category not in self._components:
        raise ValueError(
            f"Unknown category: {category}. "
            f"Available categories: {list(self._components.keys())}"
        )

    if name not in self._components[category]:
        available = list(self._components[category].keys())
        raise ValueError(f"Unknown {category}: '{name}'. Available {category}s: {available}")

    return self._components[category][name]
list
list(category: ComponentType | None = None) -> list[str] | dict[str, list[str]]

List registered components.

Parameters:

Name Type Description Default
category ComponentType | None

Category to list components for. If None, lists all categories.

None

Returns:

Type Description
list[str] | dict[str, list[str]]

If category is specified, returns a list of component names. Otherwise, returns a dict mapping each category to its component names.

Raises:

Type Description
ValueError

If the category is unknown.

Source code in spectrans/core/registry.py
def list(self, category: ComponentType | None = None) -> list[str] | dict[str, list[str]]:
    """List registered components.

    Parameters
    ----------
    category : ComponentType | None, default=None
        Category to list components for. If None, lists all categories.

    Returns
    -------
    list[str] | dict[str, list[str]]
        If category is specified, returns list of component names.
        Otherwise, returns dict mapping categories to component names.
    """
    if category is not None:
        if category not in self._components:
            raise ValueError(
                f"Unknown category: {category}. "
                f"Available categories: {list(self._components.keys())}"
            )
        return list(self._components[category].keys())

    return {cat: list(comps.keys()) for cat, comps in self._components.items()}
get_metadata
get_metadata(category: ComponentType, name: str) -> dict[str, Any] | None

Get metadata for a registered component.

Parameters:

Name Type Description Default
category ComponentType

Category of the component.

required
name str

Name of the component.

required

Returns:

Type Description
dict[str, Any] | None

Metadata dictionary if available, None otherwise.

Source code in spectrans/core/registry.py
def get_metadata(
    self,
    category: ComponentType,
    name: str,
) -> dict[str, Any] | None:
    """Get metadata for a registered component.

    Parameters
    ----------
    category : ComponentType
        Category of the component.
    name : str
        Name of the component.

    Returns
    -------
    dict[str, Any] | None
        Metadata dictionary if available, None otherwise.
    """
    return self._metadata.get(category, {}).get(name)
create
create(category: ComponentType, name: str, **kwargs: Any) -> Any

Create an instance of a registered component.

Parameters:

Name Type Description Default
category ComponentType

Category of the component.

required
name str

Name of the component.

required
**kwargs Any

Keyword arguments to pass to the component constructor.

{}

Returns:

Type Description
Any

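Instance of the component.

Raises:

Type Description
ValueError

If the category or component name is not found (raised by get).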

Source code in spectrans/core/registry.py
def create(
    self,
    category: ComponentType,
    name: str,
    **kwargs: Any,
) -> Any:
    """Create an instance of a registered component.

    Parameters
    ----------
    category : ComponentType
        Category of the component.
    name : str
        Name of the component.
    **kwargs : Any
        Keyword arguments to pass to the component constructor.

    Returns
    -------
    Any
        Instance of the component.
    """
    component_class = self.get(category, name)
    return component_class(**kwargs)
create_from_config
create_from_config(category: ComponentType, config: ConfigDict) -> Any

Create a component instance from a configuration dictionary.

Parameters:

Name Type Description Default
category ComponentType

Category of the component.

required
config ConfigDict

Configuration dictionary with 'type' and optional 'params' keys.

required

Returns:

Type Description
Any

Instance of the component.

Raises:

Type Description
ValueError

If 'type' key is missing from config.

Source code in spectrans/core/registry.py
def create_from_config(
    self,
    category: ComponentType,
    config: ConfigDict,
) -> Any:
    """Create a component instance from a configuration dictionary.

    Parameters
    ----------
    category : ComponentType
        Category of the component.
    config : ConfigDict
        Configuration dictionary with 'type' and optional 'params' keys.

    Returns
    -------
    Any
        Instance of the component.

    Raises
    ------
    ValueError
        If 'type' key is missing from config.
    """
    if "type" not in config:
        raise ValueError("Configuration must contain 'type' key")

    name = config["type"]
    raw_params: (
        int
        | float
        | str
        | bool
        | list[int | float | str | bool]
        | dict[str, int | float | str | bool | list[int | float | str | bool]]
    ) = config.get("params", {})
    if isinstance(raw_params, dict):
        params: dict[str, int | float | str | bool | list[int | float | str | bool]] = (
            raw_params
        )
    else:
        params = {}

    return self.create(category, name, **params)  # type: ignore[arg-type]
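The expected configuration shape is {"type": <registered name>, "params": {<constructor kwargs>}}. The sketch below mirrors the create_from_config logic for a single category using a plain dict of classes; GlobalFilterStub, the "global_filter" key, and create_from_config_sketch are hypothetical names chosen for illustration.

```python
from typing import Any


class GlobalFilterStub:
    """Hypothetical component used only for this illustration."""

    def __init__(self, hidden_dim: int = 256, sequence_length: int = 128) -> None:
        self.hidden_dim = hidden_dim
        self.sequence_length = sequence_length


def create_from_config_sketch(classes: dict[str, type], config: dict[str, Any]) -> Any:
    """Simplified mirror of ComponentRegistry.create_from_config for one category."""
    if "type" not in config:
        raise ValueError("Configuration must contain 'type' key")
    params = config.get("params", {})
    if not isinstance(params, dict):  # non-dict params are silently discarded
        params = {}
    return classes[config["type"]](**params)


mixing_classes = {"global_filter": GlobalFilterStub}
layer = create_from_config_sketch(
    mixing_classes, {"type": "global_filter", "params": {"hidden_dim": 768}}
)
# A malformed "params" value does not raise; constructor defaults are used instead.
fallback = create_from_config_sketch(mixing_classes, {"type": "global_filter", "params": "768"})
print(layer.hidden_dim, layer.sequence_length, fallback.hidden_dim)
```

Since a non-dict params value is dropped rather than rejected, a typo of that kind in a configuration file yields a component built from its defaults instead of an error.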
__contains__
__contains__(item: tuple[ComponentType, str]) -> bool

Check if a component is registered.

Parameters:

Name Type Description Default
item tuple[ComponentType, str]

Tuple of (category, name) to check.

required

Returns:

Type Description
bool

True if the component is registered.

Source code in spectrans/core/registry.py
def __contains__(self, item: tuple[ComponentType, str]) -> bool:
    """Check if a component is registered.

    Parameters
    ----------
    item : tuple[ComponentType, str]
        Tuple of (category, name) to check.

    Returns
    -------
    bool
        True if the component is registered.
    """
    category, name = item
    return category in self._components and name in self._components[category]
clear
clear(category: ComponentType | None = None) -> None

Clear registered components.

Parameters:

Name Type Description Default
category ComponentType | None

Category to clear. If None, clears all categories.

None

Raises:

Type Description
ValueError

If the category is unknown.

Source code in spectrans/core/registry.py
def clear(self, category: ComponentType | None = None) -> None:
    """Clear registered components.

    Parameters
    ----------
    category : ComponentType | None, default=None
        Category to clear. If None, clears all categories.
    """
    if category is not None:
        if category not in self._components:
            raise ValueError(f"Unknown category: {category}")
        self._components[category].clear()
        self._metadata[category].clear()
    else:
        for cat in self._components:
            self._components[cat].clear()
            self._metadata[cat].clear()

Functions

create_component

create_component(category: ComponentType, name: str, **kwargs: Any) -> Any

Create an instance of a registered component.

Parameters:

Name Type Description Default
category ComponentType

Category of the component.

required
name str

Name of the component.

required
**kwargs Any

Keyword arguments for the component constructor.

{}

Returns:

Type Description
Any

Instance of the component.

Source code in spectrans/core/registry.py
def create_component(
    category: ComponentType,
    name: str,
    **kwargs: Any,
) -> Any:
    """Create an instance of a registered component.

    Parameters
    ----------
    category : ComponentType
        Category of the component.
    name : str
        Name of the component.
    **kwargs : Any
        Keyword arguments for the component constructor.

    Returns
    -------
    Any
        Instance of the component.
    """
    return registry.create(category, name, **kwargs)

get_component

get_component(category: ComponentType, name: str) -> ComponentClass

Get a registered component class.

Parameters:

Name Type Description Default
category ComponentType

Category of the component.

required
name str

Name of the component.

required

Returns:

Type Description
ComponentClass

The registered component class.

Source code in spectrans/core/registry.py
def get_component(category: ComponentType, name: str) -> ComponentClass:
    """Get a registered component class.

    Parameters
    ----------
    category : ComponentType
        Category of the component.
    name : str
        Name of the component.

    Returns
    -------
    ComponentClass
        The registered component class.
    """
    return registry.get(category, name)

list_components

list_components(category: ComponentType | None = None) -> list[str] | dict[str, list[str]]

List available components.

Parameters:

Name Type Description Default
category ComponentType | None

Category to list. If None, lists all categories.

None

Returns:

Type Description
list[str] | dict[str, list[str]]

Component names or dict of categories to names.

Source code in spectrans/core/registry.py
def list_components(
    category: ComponentType | None = None,
) -> list[str] | dict[str, list[str]]:
    """List available components.

    Parameters
    ----------
    category : ComponentType | None, default=None
        Category to list. If None, lists all categories.

    Returns
    -------
    list[str] | dict[str, list[str]]
        Component names or dict of categories to names.
    """
    return registry.list(category)

register_component

register_component(category: ComponentType, name: str, metadata: dict[str, Any] | None = None) -> Callable[[ComponentClass], ComponentClass]

Decorator for registering components.

Parameters:

Name Type Description Default
category ComponentType

Category to register the component under.

required
name str

Name to register the component as.

required
metadata dict[str, Any] | None

Optional metadata about the component.

None

Returns:

Type Description
Callable[[ComponentClass], ComponentClass]

Decorator function.

Examples:

>>> @register_component('transform', 'my_fft')
... class MyFFT(SpectralTransform):
...     pass

Source code in spectrans/core/registry.py
def register_component(
    category: ComponentType,
    name: str,
    metadata: dict[str, Any] | None = None,
) -> Callable[[ComponentClass], ComponentClass]:
    """Decorator for registering components.

    Parameters
    ----------
    category : ComponentType
        Category to register the component under.
    name : str
        Name to register the component as.
    metadata : dict[str, Any] | None, default=None
        Optional metadata about the component.

    Returns
    -------
    Callable[[ComponentClass], ComponentClass]
        Decorator function.

    Examples
    --------
    >>> @register_component('transform', 'my_fft')
    ... class MyFFT(SpectralTransform):
    ...     pass
    """

    def decorator(cls: ComponentClass) -> ComponentClass:
        registry.register(category, name, cls, metadata)
        return cls

    return decorator