
Attention Layers

spectrans.layers.attention

Spectral attention layer implementations with linear or log-linear complexity.

Provides attention mechanisms based on spectral methods and kernel approximations. They achieve linear or log-linear (\(O(n \log n)\)) complexity in sequence length, compared to the quadratic complexity of standard attention. Implementations include Random Fourier Features, orthogonal transforms, and hybrid approaches.

Modules:

Name Description
lst

Linear Spectral Transform attention implementations.

spectral

Kernel-based spectral attention mechanisms.

Classes:

Name Description
DCTAttention

Specialized LST attention using the discrete cosine transform.

HadamardAttention

Fast attention using Hadamard transform operations.

KernelAttention

General kernel-based attention with various kernel options.

LSTAttention

Linear Spectral Transform attention with configurable transforms.

MixedSpectralAttention

Multi-transform attention combining multiple spectral methods.

PerformerAttention

Performer-style attention with the FAVOR+ algorithm.

SpectralAttention

Multi-head spectral attention using random Fourier features.

Examples:

Using spectral attention with RFF:

>>> import torch
>>> from spectrans.layers.attention import SpectralAttention
>>>
>>> attn = SpectralAttention(hidden_dim=512, num_heads=8, num_features=256)
>>> x = torch.randn(32, 100, 512)
>>> output = attn(x)
>>> assert output.shape == x.shape

Using LST attention with DCT:

>>> from spectrans.layers.attention import DCTAttention
>>>
>>> attn = DCTAttention(hidden_dim=512, num_heads=8)
>>> x = torch.randn(16, 128, 512)
>>> output = attn(x)

Using Performer attention (orthogonal random features are always enabled, so there is no use_orthogonal argument):

>>> from spectrans.layers.attention import PerformerAttention
>>>
>>> attn = PerformerAttention(
...     hidden_dim=768,
...     num_heads=12,
...     num_features=256,
... )
>>> x = torch.randn(8, 256, 768)
>>> output = attn(x)

Notes

Complexity Analysis:

Let \(n\) be the sequence length, \(d\) the hidden dimension, and \(k\) the number of random features. Standard attention requires \(O(n^2 d)\) time and \(O(n^2)\) memory. Spectral (random-feature) attention reduces this to \(O(n d k)\) time and \(O(n k + k d)\) memory. LST attention achieves \(O(n d \log n)\) time with \(O(n d)\) memory. Performer also runs in \(O(n d k)\) time, using orthogonal random features.
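The linear-time claim rests on evaluating \(\phi(Q)\,(\phi(K)^\top V)\) instead of \((\phi(Q)\,\phi(K)^\top)\,V\). The pure-Python sketch below (no PyTorch; the feature values are arbitrary positive numbers standing in for \(\phi(Q)\) and \(\phi(K)\), and normalization is omitted) checks that both orderings agree and counts scalar multiplications for each.

```python
import random


def matmul(a, b):
    """Dense product of list-of-lists matrices a (n x m) and b (m x p)."""
    return [
        [sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def transpose(a):
    return [list(row) for row in zip(*a)]


def mult_counts(n, d, k):
    """Scalar multiplications for the two evaluation orders of phi(Q) phi(K)^T V."""
    quadratic = n * n * k + n * n * d  # build the n x n matrix, then multiply by V
    linear = k * n * d + n * k * d     # build the k x d summary, then multiply phi(Q) by it
    return quadratic, linear


random.seed(0)
n, d, k = 6, 3, 4
phi_q = [[random.random() for _ in range(k)] for _ in range(n)]
phi_k = [[random.random() for _ in range(k)] for _ in range(n)]
v = [[random.gauss(0.0, 1.0) for _ in range(d)] for _ in range(n)]

quad_out = matmul(matmul(phi_q, transpose(phi_k)), v)  # O(n^2) intermediate
lin_out = matmul(phi_q, matmul(transpose(phi_k), v))   # no n x n intermediate

print(mult_counts(4096, 64, 256))  # (5368709120, 134217728)
```

With \(n = 4096\), \(d = 64\), \(k = 256\) the linear ordering needs 40 times fewer multiplications.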

The kernel approximation error of random features decreases as \(O(1/\sqrt{k})\).
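To see this behaviour, here is a pure-Python sketch of the classic random Fourier feature map for a Gaussian kernel, \(z(x) = \sqrt{2/k}\,\cos(\omega^\top x + b)\) with \(\omega \sim \mathcal{N}(0, \sigma^{-2} I)\) and \(b \sim \mathcal{U}[0, 2\pi)\). It is an illustration, not the code of GaussianRFFKernel (whose internals are not shown on this page), and the test points are arbitrary.

```python
import math
import random


def rff_estimate(x, y, num_features, sigma=1.0, seed=0):
    """Estimate exp(-|x - y|^2 / (2 sigma^2)) as z(x) . z(y) with random Fourier features."""
    rng = random.Random(seed)
    scale = math.sqrt(2.0 / num_features)
    total = 0.0
    for _ in range(num_features):
        w = [rng.gauss(0.0, 1.0 / sigma) for _ in x]  # frequency w ~ N(0, sigma^-2 I)
        b = rng.uniform(0.0, 2.0 * math.pi)            # random phase
        zx = scale * math.cos(sum(wi * xi for wi, xi in zip(w, x)) + b)
        zy = scale * math.cos(sum(wi * yi for wi, yi in zip(w, y)) + b)
        total += zx * zy
    return total


def gaussian_kernel(x, y, sigma=1.0):
    return math.exp(-sum((a - b) ** 2 for a, b in zip(x, y)) / (2.0 * sigma ** 2))


x, y = [0.5, -0.2, 0.1], [0.1, 0.3, -0.4]
exact = gaussian_kernel(x, y)
for k in (16, 256, 4096):
    # Expected error shrinks roughly like 1/sqrt(k); individual runs fluctuate.
    print(k, abs(rff_estimate(x, y, k) - exact))
```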

References

Krzysztof Choromanski, Valerii Likhosherstov, David Dohan, Xingyou Song, Andreea Gane, Tamas Sarlos, Peter Hawkins, Jared Davis, Afroz Mohiuddin, Lukasz Kaiser, David Belanger, Lucy Colwell, and Adrian Weller. 2021. Rethinking attention with performers. In Proceedings of the International Conference on Learning Representations (ICLR).

Angelos Katharopoulos, Apoorv Vyas, Nikolaos Pappas, and François Fleuret. 2020. Transformers are RNNs: Fast autoregressive transformers with linear attention. In Proceedings of the 37th International Conference on Machine Learning (ICML), pages 5156-5165.

See Also

spectrans.kernels : Kernel functions used by attention mechanisms.
spectrans.transforms : Spectral transforms used by LST attention.
spectrans.layers : Parent module containing all layer implementations.

Classes

DCTAttention

DCTAttention(hidden_dim: int, num_heads: int = 8, dct_type: int = 2, learnable_scale: bool = True, dropout: float = 0.0)

Bases: LSTAttention

Attention using the discrete cosine transform.

Specialized LST attention that applies the DCT in every head. The DCT is real-valued and compacts most of a smooth signal's energy into a few low-frequency coefficients.
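A small pure-Python illustration of energy compaction, assuming the orthonormal DCT-II (the exact normalization used by the library's DCT class is not shown on this page). For a smooth ramp, almost all of the energy lands in the first two coefficients, while total energy is preserved.

```python
import math


def dct2_orthonormal(x):
    """Orthonormal DCT-II: X_f = c_f * sum_t x_t cos(pi (2t + 1) f / (2N))."""
    n = len(x)
    coeffs = []
    for f in range(n):
        c_f = math.sqrt(1.0 / n) if f == 0 else math.sqrt(2.0 / n)
        coeffs.append(c_f * sum(x[t] * math.cos(math.pi * (2 * t + 1) * f / (2 * n))
                                for t in range(n)))
    return coeffs


signal = [float(t) for t in range(8)]        # a smooth ramp 0, 1, ..., 7
coeffs = dct2_orthonormal(signal)
energy = sum(s * s for s in signal)          # 140.0
coeff_energy = sum(c * c for c in coeffs)    # equals energy because the transform is orthonormal
first_two = (coeffs[0] ** 2 + coeffs[1] ** 2) / coeff_energy
print(f"share of energy in the first two coefficients: {first_two:.4f}")
```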

Parameters:

Name Type Description Default
hidden_dim int

Hidden dimension.

required
num_heads int

Number of attention heads.

8
dct_type int

DCT type. Only type 2 is currently implemented; other values are stored in dct_type, but the layer still uses a type-2 DCT.

2
learnable_scale bool

Whether to use learnable scaling.

True
dropout float

Dropout probability.

0.0
Source code in spectrans/layers/attention/lst.py
def __init__(
    self,
    hidden_dim: int,
    num_heads: int = 8,
    dct_type: int = 2,
    learnable_scale: bool = True,
    dropout: float = 0.0,
):
    super().__init__(
        hidden_dim=hidden_dim,
        num_heads=num_heads,
        transform_type="dct",
        learnable_scale=learnable_scale,
        normalize=True,
        dropout=dropout,
    )

    self.dct_type = dct_type

    # Override transform with specific DCT type
    # Note: Current DCT implementation only supports type 2
    # Future versions may support other types
    if dct_type != 2:
        # For now, still use type 2 DCT
        pass

HadamardAttention

HadamardAttention(hidden_dim: int, num_heads: int = 8, scale_by_sqrt: bool = True, learnable_scale: bool = True, dropout: float = 0.0)

Bases: LSTAttention

Attention using the fast Hadamard transform.

Uses the Hadamard transform, whose coefficients are all \(\pm 1\), for \(O(n \log n)\) attention computation.
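The \(O(n \log n)\) cost comes from the fast Walsh-Hadamard transform, which needs only additions and subtractions. This pure-Python sketch is not the library's HadamardTransform; it assumes a power-of-two length (how the library handles other lengths is not shown here). It checks the fast butterfly against the dense Sylvester matrix and shows that scaling by \(1/\sqrt{n}\) makes the transform its own inverse.

```python
import math


def fwht(x):
    """Unnormalized fast Walsh-Hadamard transform; len(x) must be a power of two."""
    n = len(x)
    if n == 0 or n & (n - 1):
        raise ValueError("length must be a power of two")
    y = list(x)
    h = 1
    while h < n:  # log2(n) butterfly stages, each doing O(n) additions/subtractions
        for start in range(0, n, 2 * h):
            for i in range(start, start + h):
                a, b = y[i], y[i + h]
                y[i], y[i + h] = a + b, a - b
        h *= 2
    return y


def hadamard_dense(n):
    """Sylvester construction H_2m = [[H, H], [H, -H]]; every entry is +1 or -1."""
    h = [[1]]
    while len(h) < n:
        h = [row + row for row in h] + [row + [-v for v in row] for row in h]
    return h


x = [3.0, 1.0, -2.0, 0.5, 4.0, 0.0, 1.5, -1.0]
fast = fwht(x)
dense = [sum(hv * xv for hv, xv in zip(row, x)) for row in hadamard_dense(len(x))]
# With a 1/sqrt(n) factor the transform is orthonormal and therefore self-inverse.
normalized = [v / math.sqrt(len(x)) for v in fast]
roundtrip = [v / math.sqrt(len(x)) for v in fwht(normalized)]
```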

Parameters:

Name Type Description Default
hidden_dim int

Hidden dimension.

required
num_heads int

Number of attention heads.

8
scale_by_sqrt bool

Whether to divide the initial diagonal scale by sqrt(head_dim).

True
learnable_scale bool

Whether to use learnable diagonal scaling.

True
dropout float

Dropout probability.

0.0
Source code in spectrans/layers/attention/lst.py
def __init__(
    self,
    hidden_dim: int,
    num_heads: int = 8,
    scale_by_sqrt: bool = True,
    learnable_scale: bool = True,
    dropout: float = 0.0,
):
    super().__init__(
        hidden_dim=hidden_dim,
        num_heads=num_heads,
        transform_type="hadamard",
        learnable_scale=learnable_scale,
        normalize=True,
        dropout=dropout,
    )

    self.scale_by_sqrt = scale_by_sqrt

    # Additional scaling for Hadamard
    if scale_by_sqrt:
        # Divide the initial scale by sqrt(head_dim)
        with torch.no_grad():
            self.scale.data = self.scale.data / math.sqrt(self.head_dim)

LSTAttention

LSTAttention(hidden_dim: int, num_heads: int = 8, transform_type: Literal['dct', 'dst', 'hadamard', 'mixed'] = 'dct', learnable_scale: bool = True, normalize: bool = True, dropout: float = 0.0, use_bias: bool = True)

Bases: AttentionLayer

Linear Spectral Transform attention mechanism.

Implements attention using orthogonal transforms (DCT, DST, Hadamard) with learnable diagonal scaling.
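To make the per-head computation concrete, here is a pure-Python sketch that mirrors the loop in LSTAttention.forward for a single head and a single sequence: an orthonormal DCT-II along the sequence axis, an element-wise product \(\hat q \odot \hat k \odot s\) in place of \(QK^\top\), the optional normalization, and the inverse transform. Projections, batching, masking, and dropout are omitted; the explicit DCT matrix and the helper names are hypothetical stand-ins for the library's transform objects.

```python
import math


def dct_matrix(n):
    """Orthonormal DCT-II matrix C (n x n); because C is orthogonal, its inverse is C^T."""
    return [
        [
            (math.sqrt(1.0 / n) if f == 0 else math.sqrt(2.0 / n))
            * math.cos(math.pi * (2 * t + 1) * f / (2 * n))
            for t in range(n)
        ]
        for f in range(n)
    ]


def along_sequence(mat, x):
    """Multiply mat (n x n) by x (n x d): transform every channel along the sequence."""
    return [
        [sum(mat[i][t] * x[t][j] for t in range(len(x))) for j in range(len(x[0]))]
        for i in range(len(mat))
    ]


def lst_head(q, k, v, scale, normalize=True, eps=1e-6):
    """One LST head; q, k, v are seq_len x head_dim and scale has head_dim entries."""
    c = dct_matrix(len(q))
    c_inv = [list(col) for col in zip(*c)]  # transpose = inverse for an orthonormal matrix
    q_t, k_t, v_t = along_sequence(c, q), along_sequence(c, k), along_sequence(c, v)
    out_t = []
    for qr, kr, vr in zip(q_t, k_t, v_t):
        attn = [a * b * s for a, b, s in zip(qr, kr, scale)]  # replaces the QK^T product
        row = [a * b for a, b in zip(attn, vr)]
        if normalize:
            norm = sum(abs(a) for a in attn) + eps  # summed over head_dim, as in forward()
            row = [r / norm for r in row]
        out_t.append(row)
    return along_sequence(c_inv, out_t)  # inverse transform back to sequence positions


q = [[0.1, 0.2], [0.3, -0.1], [0.0, 0.5], [-0.2, 0.1]]
k = [[0.2, 0.1], [0.1, 0.1], [-0.3, 0.2], [0.4, 0.0]]
v = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, -0.5]]
out = lst_head(q, k, v, scale=[1.0, 1.0])
```

Because the scale has one entry per channel, it is shared across all sequence positions, matching the (num_heads, 1, head_dim) shape used by the layer.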

Parameters:

Name Type Description Default
hidden_dim int

Hidden dimension of the model.

required
num_heads int

Number of attention heads.

8
transform_type Literal['dct', 'dst', 'hadamard', 'mixed']

Type of transform to use. "mixed" uses different transforms per head.

"dct"
learnable_scale bool

Whether to use learnable diagonal scaling matrix.

True
normalize bool

Whether to normalize in transform domain.

True
dropout float

Dropout probability.

0.0
use_bias bool

Whether to use bias in projections.

True

Attributes:

Name Type Description
head_dim int

Dimension per attention head.

transform_type str

Type of transform being used.

transforms ModuleList

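One transform per head. For transform_type "mixed", heads cycle through DCT, DST, and Hadamard; otherwise every entry refers to the same shared transform module.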

scale Parameter | Tensor

Diagonal scaling of shape (num_heads, 1, head_dim): a learnable Parameter if learnable_scale is True, otherwise a fixed buffer of ones.

Methods:

Name Description
forward

Forward pass of LST attention.

Source code in spectrans/layers/attention/lst.py
def __init__(
    self,
    hidden_dim: int,
    num_heads: int = 8,
    transform_type: Literal["dct", "dst", "hadamard", "mixed"] = "dct",
    learnable_scale: bool = True,
    normalize: bool = True,
    dropout: float = 0.0,
    use_bias: bool = True,
):
    super().__init__(hidden_dim, num_heads, dropout)

    self.head_dim = hidden_dim // num_heads
    assert self.head_dim * num_heads == hidden_dim, (
        f"hidden_dim {hidden_dim} must be divisible by num_heads {num_heads}"
    )

    self.transform_type = transform_type
    self.normalize = normalize

    # Projections
    self.q_proj = nn.Linear(hidden_dim, hidden_dim, bias=use_bias)
    self.k_proj = nn.Linear(hidden_dim, hidden_dim, bias=use_bias)
    self.v_proj = nn.Linear(hidden_dim, hidden_dim, bias=use_bias)
    self.out_proj = nn.Linear(hidden_dim, hidden_dim, bias=use_bias)

    # Initialize transforms
    self.transforms: nn.ModuleList = nn.ModuleList()  # Contains SpectralTransform objects
    if transform_type == "mixed":
        # Use different transforms for different heads
        transform_types = ["dct", "dst", "hadamard"]
        for i in range(num_heads):
            t_type = transform_types[i % len(transform_types)]
            self.transforms.append(self._create_transform(t_type))
    else:
        # Use same transform for all heads
        transform = self._create_transform(transform_type)
        for _ in range(num_heads):
            self.transforms.append(transform)

    # Learnable diagonal scaling
    if learnable_scale:
        # One scale per head and feature channel, broadcast over sequence positions
        self.scale = nn.Parameter(torch.ones(num_heads, 1, self.head_dim))
    else:
        self.register_buffer("scale", torch.ones(num_heads, 1, self.head_dim))
Functions
forward
forward(x: Tensor, mask: Tensor | None = None, return_attention: bool = False) -> Tensor | tuple[Tensor, ...]

Forward pass of LST attention.

Parameters:

Name Type Description Default
x Tensor

Input tensor of shape (batch_size, seq_len, hidden_dim).

required
mask Tensor | None

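Padding mask of shape (batch_size, seq_len), 1 (or True) at positions to keep. It is applied approximately: the mask is transformed along the sequence dimension and multiplied into the transformed keys and values, so padded positions are not strictly excluded as they are in standard attention.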

None
return_attention bool

Whether to return attention weights (not supported).

False

Returns:

Type Description
Tensor or tuple[Tensor, None]

Output tensor of shape (batch_size, seq_len, hidden_dim). If return_attention=True, returns (output, None).

Source code in spectrans/layers/attention/lst.py
def forward(
    self,
    x: Tensor,
    mask: Tensor | None = None,
    return_attention: bool = False,
) -> Tensor | tuple[Tensor, ...]:
    """Forward pass of LST attention.

    Parameters
    ----------
    x : Tensor
        Input tensor of shape (batch_size, seq_len, hidden_dim).
    mask : Tensor | None, default=None
        Attention mask of shape (batch_size, seq_len).
    return_attention : bool, default=False
        Whether to return attention weights (not supported).

    Returns
    -------
    Tensor or tuple[Tensor, Tensor]
        Output tensor of shape (batch_size, seq_len, hidden_dim).
        If return_attention=True, returns (output, None).
    """
    batch_size, seq_len, _ = x.shape

    # Linear projections
    Q = self.q_proj(x)
    K = self.k_proj(x)
    V = self.v_proj(x)

    # Reshape for multi-head attention
    Q = Q.view(batch_size, seq_len, self.num_heads, self.head_dim)
    K = K.view(batch_size, seq_len, self.num_heads, self.head_dim)
    V = V.view(batch_size, seq_len, self.num_heads, self.head_dim)

    # Transpose to (batch, heads, seq_len, head_dim)
    Q = Q.transpose(1, 2)
    K = K.transpose(1, 2)
    V = V.transpose(1, 2)

    # Apply transforms per head
    outputs = []
    for head_idx in range(self.num_heads):
        q_head = Q[:, head_idx]  # (batch, seq_len, head_dim)
        k_head = K[:, head_idx]
        v_head = V[:, head_idx]

        # Get transform for this head
        transform: SpectralTransform
        if self.transform_type == "mixed":
            transform = self.transforms[head_idx]  # type: ignore[assignment]
        else:
            transform = self.transforms[0]  # type: ignore[assignment]

        # Apply transform along sequence dimension
        q_transformed = transform.transform(q_head, dim=-2)
        k_transformed = transform.transform(k_head, dim=-2)
        v_transformed = transform.transform(v_head, dim=-2)

        # Apply mask in transform domain if provided
        if mask is not None:
            # Transform mask to frequency domain
            mask_float = mask.float().unsqueeze(-1)  # (batch, seq_len, 1)
            mask_transformed = transform.transform(mask_float, dim=-2)
            k_transformed = k_transformed * mask_transformed
            v_transformed = v_transformed * mask_transformed

        # Element-wise multiplication in transform domain
        # This replaces the QK^T computation
        attention_transformed = q_transformed * k_transformed * self.scale[head_idx]

        # Apply to values
        output_transformed = attention_transformed * v_transformed

        # Normalize if requested
        if self.normalize:
            # Compute normalization factor
            norm_factor = torch.abs(attention_transformed).sum(dim=-1, keepdim=True) + 1e-6
            output_transformed = output_transformed / norm_factor

        # Inverse transform
        output_head = transform.inverse_transform(output_transformed, dim=-2)

        # Keep only the real part if the transform returned a complex tensor
        if torch.is_complex(output_head):
            output_head = output_head.real

        outputs.append(output_head.unsqueeze(1))

    # Concatenate heads
    out = torch.cat(outputs, dim=1)  # (batch, heads, seq_len, head_dim)

    # Reshape
    out = out.transpose(1, 2).contiguous()
    out = out.view(batch_size, seq_len, self.hidden_dim)

    # Output projection and dropout
    out = self.out_proj(out)
    out = self.dropout(out)

    output: Tensor = out
    if return_attention:
        # Attention weights not available in LST
        return output, None  # type: ignore[return-value]
    return output

MixedSpectralAttention

MixedSpectralAttention(hidden_dim: int, num_heads: int = 9, use_fft: bool = True, use_dct: bool = True, use_hadamard: bool = True, dropout: float = 0.0)

Bases: AttentionLayer

Mixed spectral attention using multiple transform types.

Combines different spectral transforms across heads for diverse frequency representations.

Parameters:

Name Type Description Default
hidden_dim int

Hidden dimension; must be divisible by num_heads.

required
num_heads int

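Number of attention heads. Transforms are assigned to heads round-robin, so use a multiple of the number of enabled transforms (3 when all are enabled) for an even split.

9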
use_fft bool

Whether to include FFT heads.

True
use_dct bool

Whether to include DCT heads.

True
use_hadamard bool

Whether to include Hadamard heads.

True
dropout float

Dropout probability.

0.0
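Head assignment follows enabled_transforms[i % len(enabled_transforms)] in the constructor. The helper below is a hypothetical standalone copy of that rule, showing why a head count that is not a multiple of the number of enabled transforms gives an uneven split.

```python
from collections import Counter


def assign_heads(num_heads, use_fft=True, use_dct=True, use_hadamard=True):
    """Hypothetical copy of the constructor's round-robin head-to-transform rule."""
    flags = (("fft", use_fft), ("dct", use_dct), ("hadamard", use_hadamard))
    enabled = [name for name, on in flags if on]
    if not enabled:
        raise ValueError("At least one transform type must be enabled")
    return [enabled[i % len(enabled)] for i in range(num_heads)]


print(Counter(assign_heads(9)))          # three heads per transform
print(Counter(assign_heads(8)))          # hadamard gets one head fewer
print(assign_heads(4, use_dct=False))    # ['fft', 'hadamard', 'fft', 'hadamard']
```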

Methods:

Name Description
forward

Forward pass of mixed spectral attention.

Source code in spectrans/layers/attention/lst.py
def __init__(
    self,
    hidden_dim: int,
    num_heads: int = 9,  # Divisible by 3
    use_fft: bool = True,
    use_dct: bool = True,
    use_hadamard: bool = True,
    dropout: float = 0.0,
):
    super().__init__(hidden_dim, num_heads, dropout)

    self.head_dim = hidden_dim // num_heads

    # Count enabled transform types
    enabled_transforms = []
    if use_fft:
        enabled_transforms.append("fft")
    if use_dct:
        enabled_transforms.append("dct")
    if use_hadamard:
        enabled_transforms.append("hadamard")

    if not enabled_transforms:
        raise ValueError("At least one transform type must be enabled")

    self.enabled_transforms = enabled_transforms

    # Projections
    self.q_proj = nn.Linear(hidden_dim, hidden_dim)
    self.k_proj = nn.Linear(hidden_dim, hidden_dim)
    self.v_proj = nn.Linear(hidden_dim, hidden_dim)
    self.out_proj = nn.Linear(hidden_dim, hidden_dim)

    # Assign transforms to heads
    self.head_transforms = []
    for i in range(num_heads):
        transform_type = enabled_transforms[i % len(enabled_transforms)]
        self.head_transforms.append(transform_type)

    # Create transform modules
    from ...transforms import FFT1D

    self.fft = FFT1D() if use_fft else None
    self.dct = DCT(normalized=True) if use_dct else None
    self.hadamard = HadamardTransform(normalized=True) if use_hadamard else None

    # Learnable scales per transform type
    self.scales = nn.ParameterDict(
        {t: nn.Parameter(torch.ones(1, 1, self.head_dim)) for t in enabled_transforms}
    )
Functions
forward
forward(x: Tensor, _mask: Tensor | None = None, return_attention: bool = False) -> Tensor | tuple[Tensor, ...]

Forward pass of mixed spectral attention.

Parameters:

Name Type Description Default
x Tensor

Input of shape (batch_size, seq_len, hidden_dim).

required
_mask Tensor | None

Attention mask; ignored, because masking is not implemented for this layer.

None
return_attention bool

If True, return (output, None); attention weights are not computed for this layer.

False

Returns:

Type Description
Tensor or tuple[Tensor, None]

Output tensor of shape (batch_size, seq_len, hidden_dim). If return_attention=True, returns (output, None).

Source code in spectrans/layers/attention/lst.py
def forward(
    self,
    x: Tensor,
    _mask: Tensor | None = None,
    return_attention: bool = False,
) -> Tensor | tuple[Tensor, ...]:
    """Forward pass of mixed spectral attention.

    Parameters
    ----------
    x : Tensor
        Input of shape (batch_size, seq_len, hidden_dim).
    _mask : Tensor | None, default=None
        Attention mask (not implemented for spectral attention).
    return_attention : bool, default=False
        Whether to return attention weights.

    Returns
    -------
    Tensor or tuple[Tensor, Tensor]
        Output and optionally None for weights.
    """
    batch_size, seq_len, _ = x.shape

    # Projections
    Q = self.q_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim)
    K = self.k_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim)
    V = self.v_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim)

    Q = Q.transpose(1, 2)
    K = K.transpose(1, 2)
    V = V.transpose(1, 2)

    # Process each head with its assigned transform
    outputs = []
    for head_idx in range(self.num_heads):
        transform_type = self.head_transforms[head_idx]
        scale = self.scales[transform_type]

        q_head = Q[:, head_idx]
        k_head = K[:, head_idx]
        v_head = V[:, head_idx]

        # Apply appropriate transform
        if transform_type == "fft":
            if self.fft is None:
                raise RuntimeError("FFT transform not initialized")
            q_t = self.fft.transform(q_head, dim=-2)
            k_t = self.fft.transform(k_head, dim=-2)
            v_t = self.fft.transform(v_head, dim=-2)

            # Complex multiplication in frequency domain
            attn_t = q_t * k_t.conj() * scale
            out_t = attn_t * v_t

            # Inverse transform
            out_head = self.fft.inverse_transform(out_t, dim=-2).real

        elif transform_type == "dct":
            if self.dct is None:
                raise RuntimeError("DCT transform not initialized")
            q_t = self.dct.transform(q_head, dim=-2)
            k_t = self.dct.transform(k_head, dim=-2)
            v_t = self.dct.transform(v_head, dim=-2)

            attn_t = q_t * k_t * scale
            out_t = attn_t * v_t

            out_head = self.dct.inverse_transform(out_t, dim=-2)

        else:  # hadamard
            if self.hadamard is None:
                raise RuntimeError("Hadamard transform not initialized")
            q_t = self.hadamard.transform(q_head, dim=-2)
            k_t = self.hadamard.transform(k_head, dim=-2)
            v_t = self.hadamard.transform(v_head, dim=-2)

            attn_t = q_t * k_t * scale
            out_t = attn_t * v_t

            out_head = self.hadamard.inverse_transform(out_t, dim=-2)

        outputs.append(out_head.unsqueeze(1))

    # Concatenate and reshape
    out = torch.cat(outputs, dim=1)
    out = out.transpose(1, 2).contiguous()
    out = out.view(batch_size, seq_len, self.hidden_dim)

    # Output projection
    out = self.out_proj(out)
    out = self.dropout(out)

    output: Tensor = out
    if return_attention:
        return output, None  # type: ignore[return-value]
    return output
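The FFT branch multiplies q_t by the complex conjugate of k_t. By the correlation theorem, \(\mathcal{F}^{-1}(\mathcal{F}(q)\,\overline{\mathcal{F}(k)})\) is the circular cross-correlation of q and k along the sequence axis. This pure-Python sketch checks that identity for one real channel using a naive \(O(n^2)\) DFT (torch.fft computes the same transform in \(O(n \log n)\)); it leaves out the learnable scale and the further product with v_t.

```python
import cmath


def dft(x):
    """Naive O(n^2) discrete Fourier transform."""
    n = len(x)
    return [sum(x[t] * cmath.exp(-2j * cmath.pi * f * t / n) for t in range(n))
            for f in range(n)]


def idft(spec):
    """Naive inverse DFT, including the 1/n factor."""
    n = len(spec)
    return [sum(spec[f] * cmath.exp(2j * cmath.pi * f * t / n) for f in range(n)) / n
            for t in range(n)]


def circular_cross_correlation(q, k):
    """r[m] = sum_t q[(t + m) % n] * k[t] for real sequences q and k."""
    n = len(q)
    return [sum(q[(t + m) % n] * k[t] for t in range(n)) for m in range(n)]


q = [1.0, 2.0, 0.5, -1.0, 3.0]   # one channel of q_head along the sequence
k = [0.5, -1.0, 2.0, 0.0, 1.0]   # the matching channel of k_head
spectrum = [a * b.conjugate() for a, b in zip(dft(q), dft(k))]  # like q_t * k_t.conj()
via_fft = [z.real for z in idft(spectrum)]
direct = circular_cross_correlation(q, k)
```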

KernelAttention

KernelAttention(hidden_dim: int, num_heads: int = 8, kernel_type: Literal['gaussian', 'polynomial', 'spectral'] = 'gaussian', rank: int | None = None, num_features: int | None = None, dropout: float = 0.0)

Bases: AttentionLayer

General kernel-based attention with various kernel options.

Supports multiple kernel types including Gaussian, polynomial, and learnable spectral kernels.

Parameters:

Name Type Description Default
hidden_dim int

Hidden dimension; must be divisible by num_heads.

required
num_heads int

Number of heads.

8
kernel_type Literal['gaussian', 'polynomial', 'spectral']

Type of kernel to use.

"gaussian"
rank int | None

Rank for low-rank approximations (polynomial and spectral kernels). If None, uses min(64, head_dim).

None
num_features int | None

Number of random features for the Gaussian RFF kernel. If None, uses head_dim.

None
dropout float

Dropout probability.

0.0

Attributes:

Name Type Description
kernel_type str

Type of kernel being used.

rank int

Resolved rank passed to the polynomial and spectral kernels; defaults to min(64, head_dim).

Methods:

Name Description
forward

Forward pass of kernel attention.

Source code in spectrans/layers/attention/spectral.py
def __init__(
    self,
    hidden_dim: int,
    num_heads: int = 8,
    kernel_type: Literal["gaussian", "polynomial", "spectral"] = "gaussian",
    rank: int | None = None,
    num_features: int | None = None,
    dropout: float = 0.0,
):
    super().__init__(hidden_dim, num_heads, dropout)

    self.head_dim = hidden_dim // num_heads
    self.kernel_type = kernel_type
    self.rank = rank or min(64, self.head_dim)

    # Projections
    self.q_proj = nn.Linear(hidden_dim, hidden_dim)
    self.k_proj = nn.Linear(hidden_dim, hidden_dim)
    self.v_proj = nn.Linear(hidden_dim, hidden_dim)
    self.out_proj = nn.Linear(hidden_dim, hidden_dim)

    # Initialize kernel - using union type to handle different kernel types
    if kernel_type == "gaussian":
        from ...kernels import GaussianRFFKernel

        self.kernel = GaussianRFFKernel(
            input_dim=self.head_dim,
            num_features=num_features or self.head_dim,
            sigma=1.0 / math.sqrt(self.head_dim),
        )
        self.use_features = True

    elif kernel_type == "polynomial":
        from ...kernels import PolynomialSpectralKernel

        self.kernel = PolynomialSpectralKernel(
            rank=self.rank,
            degree=2,
        )
        self.use_features = False

    else:  # spectral
        from ...kernels import LearnableSpectralKernel

        self.kernel = LearnableSpectralKernel(
            input_dim=self.head_dim,
            rank=self.rank,
            trainable_eigenvectors=True,
        )
        self.use_features = True
Functions
forward
forward(x: Tensor, mask: Tensor | None = None, return_attention: bool = False) -> Tensor | tuple[Tensor, ...]

Forward pass of kernel attention.

Parameters:

Name Type Description Default
x Tensor

Input of shape (batch_size, seq_len, hidden_dim).

required
mask Tensor | None

Boolean padding mask of shape (batch_size, seq_len); True marks positions to keep.

None
return_attention bool

Whether to return attention weights.

False

Returns:

Type Description
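Tensor or tuple[Tensor, Tensor | None]

Output tensor of shape (batch_size, seq_len, hidden_dim). If return_attention=True, also returns the attention weights of shape (batch_size, num_heads, seq_len, seq_len), after softmax and dropout, for the polynomial kernel, or None for the feature-based Gaussian and spectral kernels.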

Source code in spectrans/layers/attention/spectral.py
def forward(
    self,
    x: Tensor,
    mask: Tensor | None = None,
    return_attention: bool = False,
) -> Tensor | tuple[Tensor, ...]:
    """Forward pass of kernel attention.

    Parameters
    ----------
    x : Tensor
        Input of shape (batch_size, seq_len, hidden_dim).
    mask : Tensor | None, default=None
        Attention mask.
    return_attention : bool, default=False
        Whether to return attention weights.

    Returns
    -------
    Tensor or tuple[Tensor, Tensor]
        Output and optionally attention weights.
    """
    batch_size, seq_len, _ = x.shape

    # Projections and reshape
    Q = self.q_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim)
    K = self.k_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim)
    V = self.v_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim)

    Q = Q.transpose(1, 2)
    K = K.transpose(1, 2)
    V = V.transpose(1, 2)

    if self.use_features:
        # Use feature-based approximation - kernel should be a callable (RandomFeatureMap)
        if hasattr(self.kernel, "extract_features"):
            Q_feat = self.kernel.extract_features(Q)  # type: ignore[operator]
            K_feat = self.kernel.extract_features(K)  # type: ignore[operator]
        else:
            # Call the kernel as a function
            Q_feat = self.kernel(Q)  # type: ignore[operator]
            K_feat = self.kernel(K)  # type: ignore[operator]

        if mask is not None:
            mask_exp = mask.unsqueeze(1).unsqueeze(-1)
            K_feat = K_feat.masked_fill(~mask_exp, 0)
            V = V.masked_fill(~mask_exp, 0)

        # Linear attention computation
        KV = torch.matmul(K_feat.transpose(-2, -1), V)
        out: Tensor = torch.matmul(Q_feat, KV)

        # Normalize
        K_sum = K_feat.sum(dim=-2, keepdim=True).transpose(-2, -1)
        normalizer = torch.matmul(Q_feat, K_sum) + 1e-6
        out = out / normalizer

        attention_weights: Tensor | None = None

    else:
        # Direct kernel computation (for small sequences)
        # Flatten heads and batch for kernel computation
        Q_flat = Q.reshape(-1, seq_len, self.head_dim)
        K_flat = K.reshape(-1, seq_len, self.head_dim)

        # Compute kernel matrix
        attention_weights = self.kernel.compute(Q_flat, K_flat)  # type: ignore[operator]
        attention_weights = attention_weights.view(batch_size, self.num_heads, seq_len, seq_len)

        if mask is not None:
            mask_exp = mask.unsqueeze(1).unsqueeze(2)
            attention_weights = attention_weights.masked_fill(~mask_exp, -1e9)

        # Normalize
        attention_weights = F.softmax(attention_weights, dim=-1)
        attention_weights = self.dropout(attention_weights)

        # Apply to values
        out = torch.matmul(attention_weights, V)

    # Reshape output
    out = out.transpose(1, 2).contiguous()
    out = out.view(batch_size, seq_len, self.hidden_dim)
    out = self.out_proj(out)
    out = self.dropout(out)

    if return_attention:
        return out, attention_weights  # type: ignore[return-value]
    return out
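For the polynomial kernel, KernelAttention takes the quadratic path: build the full kernel matrix, fill masked key positions with -1e9, apply a row-wise softmax, and multiply by V. The pure-Python sketch below walks through those steps for one head. PolynomialSpectralKernel.compute is not shown on this page, so the kernel \((q^\top k + 1)^2\) here is a hypothetical stand-in, and dropout is omitted.

```python
import math


def poly_kernel(q, k, degree=2, c=1.0):
    """Hypothetical stand-in kernel (q . k + c) ** degree."""
    return (sum(a * b for a, b in zip(q, k)) + c) ** degree


def softmax(row):
    m = max(row)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in row]
    total = sum(exps)
    return [e / total for e in exps]


def direct_kernel_attention(q, k, v, mask=None):
    """Quadratic path for one head: kernel matrix, -1e9 masking, row softmax, weights @ V."""
    n = len(q)
    scores = [[poly_kernel(q[i], k[j]) for j in range(n)] for i in range(n)]
    if mask is not None:  # mask[j] is True for key positions to keep
        scores = [[s if keep else -1e9 for s, keep in zip(row, mask)] for row in scores]
    weights = [softmax(row) for row in scores]
    out = [[sum(weights[i][j] * v[j][b] for j in range(n)) for b in range(len(v[0]))]
           for i in range(n)]
    return out, weights


q = [[0.1, 0.2], [0.3, -0.1], [0.0, 0.5]]
k = [[0.2, 0.1], [0.1, 0.4], [-0.3, 0.2]]
v = [[1.0, 0.0], [0.0, 1.0], [5.0, 5.0]]
out, weights = direct_kernel_attention(q, k, v, mask=[True, True, False])
```

The masked third key receives zero weight, so its large value row never reaches the output.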

PerformerAttention

PerformerAttention(hidden_dim: int, num_heads: int = 8, num_features: int | None = None, generalized: bool = False, dropout: float = 0.0)

Bases: SpectralAttention

Performer-style attention with the FAVOR+ algorithm.

Implements the Performer architecture with positive orthogonal random features (FAVOR+) for softmax kernel approximation.
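FAVOR+ rests on positive random features: with \(\omega \sim \mathcal{N}(0, I)\), the map \(\phi(x) = \exp(\omega^\top x - \lVert x\rVert^2/2)/\sqrt{m}\) satisfies \(\mathbb{E}[\phi(q)^\top \phi(k)] = \exp(q^\top k)\), the unnormalized softmax kernel. The pure-Python sketch below checks this estimate with i.i.d. Gaussian frequencies; it skips the orthogonalization step and is not the code of RFFAttentionKernel.

```python
import math
import random


def positive_features(x, omegas):
    """phi(x)_i = exp(w_i . x - |x|^2 / 2) / sqrt(m); every entry is strictly positive."""
    m = len(omegas)
    half_sq_norm = sum(v * v for v in x) / 2.0
    return [math.exp(sum(w_j * x_j for w_j, x_j in zip(w, x)) - half_sq_norm) / math.sqrt(m)
            for w in omegas]


rng = random.Random(0)
dim, m = 3, 20000
# i.i.d. standard Gaussian frequencies; FAVOR+ would additionally orthogonalize blocks of them
omegas = [[rng.gauss(0.0, 1.0) for _ in range(dim)] for _ in range(m)]

q, k = [0.3, -0.2, 0.1], [0.1, 0.2, -0.3]
phi_q, phi_k = positive_features(q, omegas), positive_features(k, omegas)
estimate = sum(a * b for a, b in zip(phi_q, phi_k))
exact = math.exp(sum(a * b for a, b in zip(q, k)))  # unnormalized softmax kernel exp(q . k)
print(estimate, exact)
```

Positivity matters for the normalizer: with strictly positive features, the denominator of linear attention cannot become negative, unlike with cosine features.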

Parameters:

Name Type Description Default
hidden_dim int

Hidden dimension.

required
num_heads int

Number of attention heads.

8
num_features int | None

Number of random features. If None, uses head_dim.

None
generalized bool

Whether to use generalized attention, which replaces the softmax-kernel features with ReLU random features.

False
dropout float

Dropout probability.

0.0

Attributes:

Name Type Description
generalized bool

Whether using generalized attention.

Methods:

Name Description
forward

Forward pass of Performer attention.

Source code in spectrans/layers/attention/spectral.py
def __init__(
    self,
    hidden_dim: int,
    num_heads: int = 8,
    num_features: int | None = None,
    generalized: bool = False,
    dropout: float = 0.0,
):
    super().__init__(
        hidden_dim=hidden_dim,
        num_heads=num_heads,
        num_features=num_features,
        kernel_type="softmax",
        use_orthogonal=True,
        feature_redraw=False,
        dropout=dropout,
    )

    self.generalized = generalized

    if generalized:
        # For generalized attention, use different kernel
        self.kernel = RFFAttentionKernel(
            input_dim=self.head_dim,
            num_features=self.num_features,
            kernel_type="relu",
            use_orthogonal=True,
        )
Functions
forward
forward(x: Tensor, mask: Tensor | None = None, return_attention: bool = False) -> Tensor | tuple[Tensor, ...]

Forward pass of Performer attention.

Parameters:

Name Type Description Default
x Tensor

Input of shape (batch_size, seq_len, hidden_dim).

required
mask Tensor | None

Boolean padding mask of shape (batch_size, seq_len); True marks positions to keep.

None
return_attention bool

If True, return (output, None). The flag is only forwarded on the standard path; it is ignored when generalized=True.

False

Returns:

Type Description
Tensor or tuple[Tensor, None]

Output tensor of shape (batch_size, seq_len, hidden_dim), paired with None when return_attention=True.

Source code in spectrans/layers/attention/spectral.py
def forward(
    self,
    x: Tensor,
    mask: Tensor | None = None,
    return_attention: bool = False,
) -> Tensor | tuple[Tensor, ...]:
    """Forward pass of Performer attention.

    Parameters
    ----------
    x : Tensor
        Input of shape (batch_size, seq_len, hidden_dim).
    mask : Tensor | None, default=None
        Attention mask.
    return_attention : bool, default=False
        Whether to return attention weights.

    Returns
    -------
    Tensor or tuple[Tensor, Tensor]
        Output tensor and optionally None for weights.
    """
    if self.generalized:
        # Generalized attention without exponential
        return self._generalized_attention(x, mask)
    else:
        # Standard Performer with softmax approximation
        return super().forward(x, mask, return_attention)

SpectralAttention

SpectralAttention(hidden_dim: int, num_heads: int = 8, num_features: int | None = None, head_dim: int | None = None, kernel_type: Literal['gaussian', 'softmax'] = 'softmax', use_orthogonal: bool = True, feature_redraw: bool = False, dropout: float = 0.0, use_bias: bool = True)

Bases: AttentionLayer

Multi-head spectral attention using RFF approximation.

Implements attention using Random Fourier Features to approximate the softmax kernel with linear complexity.

Parameters:

Name Type Description Default
hidden_dim int

Hidden dimension of the model.

required
num_heads int

Number of attention heads.

8
num_features int | None

Number of random features. If None, uses head_dim.

None
head_dim int | None

Dimension per head. If None, uses hidden_dim // num_heads; an explicit value must satisfy head_dim * num_heads == hidden_dim.

None
kernel_type Literal['gaussian', 'softmax']

Kernel to approximate: "softmax" uses RFFAttentionKernel, "gaussian" uses GaussianRFFKernel.

"softmax"
use_orthogonal bool

Whether to use orthogonal random features.

True
feature_redraw bool

Whether to redraw features at each forward pass (softmax kernel only; not passed to the Gaussian kernel).

False
dropout float

Dropout probability.

0.0
use_bias bool

Whether to use bias in projections.

True

Attributes:

Name Type Description
head_dim int

Dimension per attention head.

num_features int

Number of random features used.

q_proj Linear

Query projection.

k_proj Linear

Key projection.

v_proj Linear

Value projection.

out_proj Linear

Output projection.

kernel RandomFeatureMap | KernelFunction

Kernel for attention approximation.

Methods:

Name Description
forward

Forward pass of spectral attention.

Source code in spectrans/layers/attention/spectral.py
def __init__(
    self,
    hidden_dim: int,
    num_heads: int = 8,
    num_features: int | None = None,
    head_dim: int | None = None,
    kernel_type: Literal["gaussian", "softmax"] = "softmax",
    use_orthogonal: bool = True,
    feature_redraw: bool = False,
    dropout: float = 0.0,
    use_bias: bool = True,
):
    super().__init__(hidden_dim, num_heads, dropout)

    # Determine head dimension
    self.head_dim = head_dim or (hidden_dim // num_heads)
    assert self.head_dim * num_heads == hidden_dim, (
        f"hidden_dim {hidden_dim} must be divisible by num_heads {num_heads}"
    )

    # Number of random features
    self.num_features = num_features or self.head_dim

    # Projections
    self.q_proj = nn.Linear(hidden_dim, hidden_dim, bias=use_bias)
    self.k_proj = nn.Linear(hidden_dim, hidden_dim, bias=use_bias)
    self.v_proj = nn.Linear(hidden_dim, hidden_dim, bias=use_bias)
    self.out_proj = nn.Linear(hidden_dim, hidden_dim, bias=use_bias)

    # Kernel for approximation
    if kernel_type == "softmax":
        self.kernel = RFFAttentionKernel(
            input_dim=self.head_dim,
            num_features=self.num_features,
            kernel_type="softmax",
            use_orthogonal=use_orthogonal,
            redraw=feature_redraw,
        )
    else:  # gaussian
        self.kernel = GaussianRFFKernel(
            input_dim=self.head_dim,
            num_features=self.num_features,
            sigma=1.0 / math.sqrt(self.head_dim),
            orthogonal=use_orthogonal,
        )

    # Normalization
    self.scale = 1.0 / math.sqrt(self.num_features)
Functions
forward
forward(x: Tensor, mask: Tensor | None = None, return_attention: bool = False) -> Tensor | tuple[Tensor, ...]

Forward pass of spectral attention.

Parameters:

Name Type Description Default
x Tensor

Input tensor of shape (batch_size, seq_len, hidden_dim).

required
mask Tensor | None

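Boolean padding mask of shape (batch_size, seq_len); True marks positions to keep. Masked keys and values are zeroed before the linear-attention sums.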

None
return_attention bool

Whether to return attention weights (not supported).

False

Returns:

Type Description
Tensor or tuple[Tensor, None]

Output tensor of shape (batch_size, seq_len, hidden_dim). If return_attention=True, also returns None (weights not available).

Source code in spectrans/layers/attention/spectral.py
def forward(
    self,
    x: Tensor,
    mask: Tensor | None = None,
    return_attention: bool = False,
) -> Tensor | tuple[Tensor, ...]:
    """Forward pass of spectral attention.

    Parameters
    ----------
    x : Tensor
        Input tensor of shape (batch_size, seq_len, hidden_dim).
    mask : Tensor | None, default=None
        Attention mask of shape (batch_size, seq_len).
    return_attention : bool, default=False
        Whether to return attention weights (not supported).

    Returns
    -------
    Tensor or tuple[Tensor, Tensor]
        Output tensor of shape (batch_size, seq_len, hidden_dim).
        If return_attention=True, also returns None (weights not available).
    """
    batch_size, seq_len, _ = x.shape

    # Linear projections
    Q = self.q_proj(x)
    K = self.k_proj(x)
    V = self.v_proj(x)

    # Reshape for multi-head attention
    Q = Q.view(batch_size, seq_len, self.num_heads, self.head_dim)
    K = K.view(batch_size, seq_len, self.num_heads, self.head_dim)
    V = V.view(batch_size, seq_len, self.num_heads, self.head_dim)

    # Transpose to (batch, heads, seq_len, head_dim)
    Q = Q.transpose(1, 2)
    K = K.transpose(1, 2)
    V = V.transpose(1, 2)

    # Apply kernel feature maps
    Q_features = self.kernel(Q)  # (batch, heads, seq_len, num_features)
    K_features = self.kernel(K)  # (batch, heads, seq_len, num_features)

    # Apply mask if provided
    if mask is not None:
        # Expand mask for heads dimension
        mask = mask.unsqueeze(1).unsqueeze(-1)  # (batch, 1, seq_len, 1)
        K_features = K_features.masked_fill(~mask, 0)
        V = V.masked_fill(~mask, 0)

    # Compute KV (batch, heads, num_features, head_dim)
    KV = torch.matmul(K_features.transpose(-2, -1), V)

    # Compute QKV (batch, heads, seq_len, head_dim)
    out: Tensor = torch.matmul(Q_features, KV)

    # Normalize
    # Compute normalizer: Q_features @ (K_features^T @ 1)
    K_sum = K_features.sum(dim=-2, keepdim=True).transpose(-2, -1)
    normalizer = torch.matmul(Q_features, K_sum) + 1e-6
    out = out / normalizer

    # Transpose back and reshape
    out = out.transpose(1, 2).contiguous()
    out = out.view(batch_size, seq_len, self.hidden_dim)

    # Output projection and dropout
    out = self.out_proj(out)
    out = self.dropout(out)

    if return_attention:
        # Attention weights not directly available in linear attention
        return out, None  # type: ignore[return-value]
    return out
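The forward pass above zeroes masked keys and values, then divides \(\phi(Q)(\phi(K)^\top V)\) by \(\phi(Q)(\phi(K)^\top \mathbf{1})\). The pure-Python sketch below (single head, single sequence, random positive stand-ins for the kernel features, hypothetical helper names) checks that this matches explicitly row-normalized attention with weights \(\phi(q_i)^\top \phi(k_j)\) over the unmasked keys.

```python
import random


def linear_attention(phi_q, phi_k, v, mask):
    """O(n k d) path: zero masked rows, build k x d summaries, then normalize per query."""
    k_dim, d = len(phi_k[0]), len(v[0])
    phi_k = [row if keep else [0.0] * k_dim for row, keep in zip(phi_k, mask)]
    v = [row if keep else [0.0] * d for row, keep in zip(v, mask)]
    kv = [[sum(fk[a] * vv[b] for fk, vv in zip(phi_k, v)) for b in range(d)]
          for a in range(k_dim)]                                # phi(K)^T V
    k_sum = [sum(fk[a] for fk in phi_k) for a in range(k_dim)]  # phi(K)^T 1
    out = []
    for fq in phi_q:
        numer = [sum(fq[a] * kv[a][b] for a in range(k_dim)) for b in range(d)]
        denom = sum(fq[a] * k_sum[a] for a in range(k_dim)) + 1e-6
        out.append([x / denom for x in numer])
    return out


def explicit_attention(phi_q, phi_k, v, mask):
    """O(n^2) reference: weights phi_q[i] . phi_k[j] for kept j, normalized per row."""
    out = []
    for fq in phi_q:
        w = [sum(a * b for a, b in zip(fq, fk)) if keep else 0.0
             for fk, keep in zip(phi_k, mask)]
        total = sum(w) + 1e-6
        out.append([sum(wj * vv[b] for wj, vv in zip(w, v)) / total
                    for b in range(len(v[0]))])
    return out


random.seed(1)
n, k, d = 5, 3, 2
phi_q = [[random.random() for _ in range(k)] for _ in range(n)]
phi_k = [[random.random() for _ in range(k)] for _ in range(n)]
v = [[random.gauss(0.0, 1.0) for _ in range(d)] for _ in range(n)]
mask = [True, True, True, False, False]  # the last two positions are padding

fast = linear_attention(phi_q, phi_k, v, mask)
slow = explicit_attention(phi_q, phi_k, v, mask)
```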