
Utility Functions

spectrans.utils

Utility functions for spectral transformer implementations.

This module provides utility functions for spectral neural networks, including specialized complex number operations, initialization schemes for spectral parameters, and padding utilities for signal processing operations.

These utilities are designed to support the mathematical rigor and numerical stability required for spectral transformer architectures while providing convenient abstractions for common operations.

Modules:

  • complex: Complex tensor operations and utilities.
  • initialization: Parameter initialization schemes for spectral networks.
  • padding: Padding utilities for signal processing.

Functions:

  • complex_conjugate: Compute complex conjugate with proper error handling.
  • complex_multiply: Element-wise complex multiplication with broadcasting.
  • complex_divide: Complex division with zero-division safety checks.
  • complex_modulus: Compute magnitude of complex tensors.
  • complex_phase: Extract phase angles from complex tensors.
  • complex_polar: Construct complex tensors from polar coordinates.
  • complex_exp: Complex exponential function.
  • complex_log: Complex logarithm with numerical safety.
  • complex_relu: ReLU activation applied to both real and imaginary parts.
  • complex_dropout: Dropout preserving phase relationships.
  • make_complex: Construct complex tensors from real/imaginary parts.
  • split_complex: Split complex tensors into real/imaginary components.
  • spectral_init: Initialize parameters for spectral neural networks.
  • frequency_init: Initialize parameters with frequency-domain properties.
  • orthogonal_spectral_init: Initialize with orthogonality constraints.
  • complex_xavier_init: Xavier initialization for complex-valued parameters.
  • complex_kaiming_init: Kaiming initialization for complex parameters.
  • pad_to_power_of_2: Pad tensor to next power of 2 for efficient FFT.
  • pad_for_fft: Pad tensor for FFT operations.
  • circular_pad: Apply circular (periodic) padding.
  • reflect_pad: Apply reflection padding for boundary handling.

Examples:

Complex number operations:

>>> import torch
>>> from spectrans.utils import complex_multiply, complex_polar, split_complex
>>> # Create complex tensors
>>> z1 = torch.complex(torch.randn(10), torch.randn(10))
>>> z2 = torch.complex(torch.randn(10), torch.randn(10))
>>> product = complex_multiply(z1, z2)
>>>
>>> # Convert to polar form
>>> magnitude = torch.abs(z1)
>>> phase = torch.angle(z1)
>>> z1_reconstructed = complex_polar(magnitude, phase)

Spectral parameter initialization:

>>> from spectrans.utils import spectral_init, complex_xavier_init
>>> import torch.nn as nn
>>> # Initialize a linear layer for spectral transforms
>>> linear = nn.Linear(512, 512)
>>> spectral_init(linear.weight, method='frequency')
>>>
>>> # Initialize complex-valued parameters
>>> complex_params = torch.empty(256, 256, dtype=torch.complex64)
>>> complex_xavier_init(complex_params)

Padding for spectral operations:

>>> from spectrans.utils import pad_to_power_of_2, pad_for_fft
>>> signal = torch.randn(32, 500)  # 500 is not power of 2
>>> padded = pad_to_power_of_2(signal, dim=-1)  # Pads to 512
>>>
>>> # Pad to specific FFT length
>>> fft_ready = pad_for_fft(signal, target_length=1024, dim=-1)
Notes

Design Philosophy:

The utility functions follow these principles:

  1. Mathematical Safety: All operations check their inputs and reject undefined cases with clear errors, such as exact zeros in the denominator of complex_divide or in the argument of complex_log

  2. Numerical Stability: Implementations prioritize numerical stability over raw performance where trade-offs exist

  3. Type Safety: Type checking and clear error messages for incorrect usage patterns

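  4. Gradient Compatibility: The tensor operations support automatic differentiation for end-to-end neural network training; the initializers deliberately write their values under torch.no_grad(), so initialization is not recorded by autograd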

  5. Broadcasting Support: Operations follow PyTorch broadcasting conventions for flexible tensor manipulation

Complex Number Operations:

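The complex utilities provide a consistent interface for complex tensor operations. Most are thin wrappers around PyTorch functions (torch.conj, torch.mul, torch.div, torch.abs, torch.angle, torch.polar, torch.exp, torch.log, torch.complex) that add domain-specific validation for spectral neural networks: dtype checks, zero checks before division and logarithms, and clearer error messages when shapes cannot be broadcast.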

Initialization Schemes:

Spectral neural networks often require specialized parameter initialization due to:

  • Different scaling properties of spectral transforms
  • Complex-valued parameters requiring magnitude/phase initialization
  • Orthogonality constraints for certain spectral methods
  • Frequency-domain parameter interpretation

Padding Utilities:

Signal processing operations often require specific padding strategies:

  • Power-of-2 lengths for efficient FFT computation
  • Circular padding for periodic signal assumptions
  • Reflection padding for boundary effect minimization
  • Zero padding with proper unpadding for shape restoration
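The padding strategies listed above correspond to standard PyTorch padding modes. The following is a minimal sketch that uses torch.nn.functional.pad directly rather than spectrans's own pad_to_power_of_2, pad_for_fft, circular_pad, or reflect_pad, whose signatures are not documented on this page; next_power_of_2 is a hypothetical helper.

```python
import torch
import torch.nn.functional as F


def next_power_of_2(n: int) -> int:
    """Hypothetical helper: smallest power of 2 that is >= n (for n >= 1)."""
    return 1 << (n - 1).bit_length()


# Shape (batch, channels, length) = (1, 1, 5) holding the samples 1..5
x = torch.arange(1.0, 6.0).view(1, 1, 5)
length = x.shape[-1]

# Zero padding up to the next power of 2 for the FFT, then unpadding
target = next_power_of_2(length)              # 8
zero_padded = F.pad(x, (0, target - length))  # [1, 2, 3, 4, 5, 0, 0, 0]
restored = zero_padded[..., :length]          # back to the original shape

# Circular padding: wrap samples from the opposite end (periodic signal)
circular = F.pad(x, (2, 2), mode="circular")  # [4, 5, 1, 2, 3, 4, 5, 1, 2]

# Reflection padding: mirror about the edge sample without repeating it
reflected = F.pad(x, (2, 2), mode="reflect")  # [3, 2, 1, 2, 3, 4, 5, 4, 3]
```

Circular padding matches the periodic assumption built into the DFT, while reflection avoids the artificial jump that wrapping can introduce at the boundaries.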

Performance Considerations:

  • All utilities are optimized for batch operations
  • GPU acceleration through native PyTorch operations
  • Memory efficiency with in-place operations where safe
  • Vectorized implementations for throughput
See Also

spectrans.utils.complex : Complex tensor operations
spectrans.utils.initialization : Parameter initialization schemes
spectrans.utils.padding : Padding utilities for signal processing
spectrans.transforms : Spectral transforms using these utilities

Functions

complex_conjugate

complex_conjugate(x: Tensor) -> Tensor

Compute complex conjugate of input tensor.

Essential operation for spectral transforms, particularly for ensuring Hermitian symmetry in frequency domain operations.

Parameters:

  • x (Tensor, required): Input complex tensor.

Returns:

  • Tensor: Complex conjugate tensor.

Raises:

  • TypeError: If input is not a complex tensor.

Source code in spectrans/utils/complex.py
def complex_conjugate(x: Tensor) -> Tensor:
    """Compute complex conjugate of input tensor.

    Essential operation for spectral transforms, particularly for ensuring
    Hermitian symmetry in frequency domain operations.

    Parameters
    ----------
    x : Tensor
        Input complex tensor.

    Returns
    -------
    Tensor
        Complex conjugate tensor.

    Raises
    ------
    TypeError
        If input is not a complex tensor.
    """
    if not x.is_complex():
        raise TypeError(f"Input must be complex tensor, got {x.dtype}")

    return torch.conj(x)
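To illustrate the Hermitian-symmetry use case, the sketch below checks that the DFT of a real signal satisfies X[k] = conj(X[(N - k) mod N]). The function body is copied from the listing above so the snippet runs without spectrans installed. Note that torch.conj returns a lazy view with the conjugate bit set; resolve_conj() materializes it when a plain tensor is needed.

```python
import torch
from torch import Tensor


def complex_conjugate(x: Tensor) -> Tensor:
    # Copied from the spectrans source listing above
    if not x.is_complex():
        raise TypeError(f"Input must be complex tensor, got {x.dtype}")
    return torch.conj(x)


torch.manual_seed(0)
signal = torch.randn(8, dtype=torch.float64)  # real-valued input
spectrum = torch.fft.fft(signal)

# Reorder so that mirrored[k] == spectrum[(N - k) % N]
mirrored = torch.roll(torch.flip(spectrum, dims=[0]), shifts=1, dims=0)

# torch.conj returns a lazy view; resolve_conj() materializes the values
conj_mirrored = complex_conjugate(mirrored).resolve_conj()

# Hermitian symmetry of a real signal's DFT: X[k] == conj(X[(N - k) % N])
hermitian = torch.allclose(spectrum, conj_mirrored)
```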

complex_divide

complex_divide(a: Tensor, b: Tensor) -> Tensor

Divide two complex tensors element-wise.

Essential for spectral filtering operations. Includes safety checks for division by zero, which can occur in spectral nulls.

Parameters:

  • a (Tensor, required): Numerator complex tensor.
  • b (Tensor, required): Denominator complex tensor.

Returns:

  • Tensor: Complex division result.

Raises:

  • TypeError: If inputs are not complex tensors.
  • ValueError: If denominator contains zeros.
  • RuntimeError: If tensors cannot be broadcast together.

Source code in spectrans/utils/complex.py
def complex_divide(a: Tensor, b: Tensor) -> Tensor:
    """Divide two complex tensors element-wise.

    Essential for spectral filtering operations. Includes safety checks
    for division by zero, which can occur in spectral nulls.

    Parameters
    ----------
    a : Tensor
        Numerator complex tensor.
    b : Tensor
        Denominator complex tensor.

    Returns
    -------
    Tensor
        Complex division result.

    Raises
    ------
    TypeError
        If inputs are not complex tensors.
    ValueError
        If denominator contains zeros.
    RuntimeError
        If tensors cannot be broadcast together.
    """
    if not a.is_complex():
        raise TypeError(f"Numerator must be complex tensor, got {a.dtype}")
    if not b.is_complex():
        raise TypeError(f"Denominator must be complex tensor, got {b.dtype}")

    # Check for zeros in denominator
    if torch.any(torch.abs(b) == 0):
        raise ValueError("Division by zero in denominator")

    try:
        return torch.div(a, b)
    except RuntimeError as e:
        raise RuntimeError(f"Cannot broadcast tensors with shapes {a.shape} and {b.shape}") from e
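A typical spectral-filtering use is undoing a known filter response by division. This sketch copies complex_divide from the listing above; the spectra and the response are random stand-ins, with one response broadcast over a batch. Because the check only rejects exact zeros, near-zero denominators still pass and can amplify noise.

```python
import torch
from torch import Tensor


def complex_divide(a: Tensor, b: Tensor) -> Tensor:
    # Copied from the spectrans source listing above
    if not a.is_complex():
        raise TypeError(f"Numerator must be complex tensor, got {a.dtype}")
    if not b.is_complex():
        raise TypeError(f"Denominator must be complex tensor, got {b.dtype}")
    if torch.any(torch.abs(b) == 0):
        raise ValueError("Division by zero in denominator")
    try:
        return torch.div(a, b)
    except RuntimeError as e:
        raise RuntimeError(f"Cannot broadcast tensors with shapes {a.shape} and {b.shape}") from e


torch.manual_seed(0)
spectrum = torch.fft.fft(torch.randn(4, 16, dtype=torch.float64))  # batch of 4 spectra
response = torch.fft.fft(torch.randn(16, dtype=torch.float64))     # one filter response

# Filtering multiplies spectra; dividing by the same response undoes it
filtered = spectrum * response
recovered = complex_divide(filtered, response)

# An exact zero in the denominator (a spectral null) is rejected
nulled = response.clone()
nulled[3] = 0
try:
    complex_divide(filtered, nulled)
    null_rejected = False
except ValueError:
    null_rejected = True
```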

complex_dropout

complex_dropout(x: Tensor, p: float = 0.5, training: bool = True) -> Tensor

Apply dropout to complex tensor.

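Applies dropout to the magnitude and keeps the phase unchanged. torch.nn.functional.dropout zeroes each magnitude with probability p and rescales the survivors by 1 / (1 - p), so every element is either zeroed as a whole or kept with its original phase. Independent real/imaginary dropout, by contrast, can zero one component of an element and thereby rotate its phase.

This provides regularization while preserving the complex structure that spectral transformations rely on.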

Parameters:

  • x (Tensor, required): Input complex tensor.
  • p (float, default 0.5): Dropout probability.
  • training (bool, default True): Whether in training mode; if False, or if p == 0, x is returned unchanged.

Returns:

  • Tensor: Complex tensor with dropout applied.

Raises:

  • TypeError: If input is not a complex tensor.
  • ValueError: If dropout probability is not in [0, 1].

Source code in spectrans/utils/complex.py
def complex_dropout(x: Tensor, p: float = 0.5, training: bool = True) -> Tensor:
    """Apply dropout to complex tensor.

    Applies dropout to magnitude while preserving phase relationships.
    This is superior to independent real/imaginary dropout for spectral data.

    This specialized dropout maintains the complex structure essential for
    spectral transformations while providing regularization.

    Parameters
    ----------
    x : Tensor
        Input complex tensor.
    p : float, default=0.5
        Dropout probability.
    training : bool, default=True
        Whether in training mode.

    Returns
    -------
    Tensor
        Complex tensor with dropout applied.

    Raises
    ------
    TypeError
        If input is not a complex tensor.
    ValueError
        If dropout probability is not in [0, 1].
    """
    if not x.is_complex():
        raise TypeError(f"Input must be complex tensor, got {x.dtype}")

    if not 0.0 <= p <= 1.0:
        raise ValueError(f"Dropout probability must be in [0, 1], got {p}")

    if not training or p == 0.0:
        return x

    # Create dropout mask for the magnitude
    # This preserves phase relationships better than independent dropout
    magnitude = torch.abs(x)
    phase = torch.angle(x)

    # Apply dropout to magnitude only
    dropped_magnitude = torch.nn.functional.dropout(magnitude, p=p, training=training)

    # Reconstruct complex tensor with same phases
    return torch.polar(dropped_magnitude, phase)
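The claim that each element is either zeroed or kept with its phase can be checked directly. The sketch below copies complex_dropout from the listing above and verifies that every surviving element equals x / (1 - p), which means its phase is unchanged.

```python
import torch
from torch import Tensor


def complex_dropout(x: Tensor, p: float = 0.5, training: bool = True) -> Tensor:
    # Copied from the spectrans source listing above
    if not x.is_complex():
        raise TypeError(f"Input must be complex tensor, got {x.dtype}")
    if not 0.0 <= p <= 1.0:
        raise ValueError(f"Dropout probability must be in [0, 1], got {p}")
    if not training or p == 0.0:
        return x
    magnitude = torch.abs(x)
    phase = torch.angle(x)
    dropped_magnitude = torch.nn.functional.dropout(magnitude, p=p, training=training)
    return torch.polar(dropped_magnitude, phase)


torch.manual_seed(0)
p = 0.25
x = torch.randn(1000, dtype=torch.complex128)
out = complex_dropout(x, p=p)

# Surviving elements equal x / (1 - p): same phase, rescaled magnitude
kept = out != 0
rescaled_ok = torch.allclose(out[kept], x[kept] / (1 - p))

# In evaluation mode the input passes through unchanged
eval_out = complex_dropout(x, p=p, training=False)
```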

complex_exp

complex_exp(x: Tensor) -> Tensor

Compute complex exponential e^x.

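Core operation for Fourier transforms and oscillatory functions. Accepts both real and complex inputs and delegates to torch.exp, so a real input yields a real result; pass an imaginary argument such as 1j * theta to obtain the phasor e^(i*theta).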

Parameters:

  • x (Tensor, required): Input tensor (can be real or complex).

Returns:

  • Tensor: Exponential of x with the same dtype as x; real inputs give real outputs and only complex inputs give complex outputs.

Source code in spectrans/utils/complex.py
def complex_exp(x: Tensor) -> Tensor:
    """Compute complex exponential e^x.

    Core operation for Fourier transforms and oscillatory functions.
    Accepts both real and complex inputs for flexibility.

    Parameters
    ----------
    x : Tensor
        Input tensor (can be real or complex).

    Returns
    -------
    Tensor
        Exponential of x; complex only if x is complex.
    """
    return torch.exp(x)
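Because complex_exp is a thin wrapper around torch.exp, the output dtype follows the input. This sketch (function copied from the listing above) shows that a real input stays real, and that the unit phasors e^(i*theta) used in Fourier analysis come from passing 1j * theta, matching torch.polar with unit magnitude.

```python
import math

import torch
from torch import Tensor


def complex_exp(x: Tensor) -> Tensor:
    # Copied from the spectrans source listing above: a thin wrapper over torch.exp
    return torch.exp(x)


theta = torch.linspace(-math.pi, math.pi, 9, dtype=torch.float64)

# Real input stays real: torch.exp does not promote to complex
real_result = complex_exp(theta)

# Unit phasors e^(i*theta): pass a purely imaginary argument
phasors = complex_exp(1j * theta)

# Equivalent construction from polar coordinates (magnitude 1, angle theta)
polar_phasors = torch.polar(torch.ones_like(theta), theta)
```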

complex_log

complex_log(x: Tensor) -> Tensor

Compute complex natural logarithm.

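Used in spectral domain operations and inverse transforms. Returns the principal branch, log|x| + i*arg(x) with arg(x) in (-π, π]. Includes a safety check for exact zeros, where the logarithm is undefined; very small but nonzero magnitudes pass the check and produce large negative real parts.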

Parameters:

  • x (Tensor, required): Input complex tensor.

Returns:

  • Tensor: Complex logarithm tensor.

Raises:

  • TypeError: If input is not a complex tensor.
  • ValueError: If input contains zeros (logarithm undefined).

Source code in spectrans/utils/complex.py
def complex_log(x: Tensor) -> Tensor:
    """Compute complex natural logarithm.

    Used in spectral domain operations and inverse transforms.
    Includes safety check for zeros where logarithm is undefined.

    Parameters
    ----------
    x : Tensor
        Input complex tensor.

    Returns
    -------
    Tensor
        Complex logarithm tensor.

    Raises
    ------
    TypeError
        If input is not a complex tensor.
    ValueError
        If input contains zeros (logarithm undefined).
    """
    if not x.is_complex():
        raise TypeError(f"Input must be complex tensor, got {x.dtype}")

    # Check for zeros where log is undefined
    if torch.any(torch.abs(x) == 0):
        raise ValueError("Logarithm undefined for zero values")

    return torch.log(x)
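A quick check of the branch convention, relying on the behaviour of torch.log for complex inputs: the real part is log|z|, the imaginary part is the principal argument, and exp inverts log. The function is copied from the listing above.

```python
import math

import torch
from torch import Tensor


def complex_log(x: Tensor) -> Tensor:
    # Copied from the spectrans source listing above
    if not x.is_complex():
        raise TypeError(f"Input must be complex tensor, got {x.dtype}")
    if torch.any(torch.abs(x) == 0):
        raise ValueError("Logarithm undefined for zero values")
    return torch.log(x)


torch.manual_seed(0)
z = torch.randn(64, dtype=torch.complex128)
log_z = complex_log(z)

# log(z) = log|z| + i*arg(z), with arg on the principal branch (-pi, pi]
real_ok = torch.allclose(log_z.real, torch.log(z.abs()))
imag_ok = torch.allclose(log_z.imag, torch.angle(z))

# exp undoes log on the principal branch
roundtrip_ok = torch.allclose(torch.exp(log_z), z)
```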

complex_modulus

complex_modulus(x: Tensor) -> Tensor

Compute magnitude (absolute value) of complex tensor.

Critical for spectral analysis: the squared magnitude |X[k]|^2 of a spectrum gives the energy in frequency bin k.

Parameters:

  • x (Tensor, required): Input complex tensor.

Returns:

  • Tensor: Real tensor containing magnitudes.

Raises:

  • TypeError: If input is not a complex tensor.

Source code in spectrans/utils/complex.py
def complex_modulus(x: Tensor) -> Tensor:
    """Compute magnitude (absolute value) of complex tensor.

    Critical for spectral analysis where magnitude represents signal energy.

    Parameters
    ----------
    x : Tensor
        Input complex tensor.

    Returns
    -------
    Tensor
        Real tensor containing magnitudes.

    Raises
    ------
    TypeError
        If input is not a complex tensor.
    """
    if not x.is_complex():
        raise TypeError(f"Input must be complex tensor, got {x.dtype}")

    return torch.abs(x)
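Since the squared magnitude measures energy per frequency bin, Parseval's theorem gives a simple sanity check: for PyTorch's default unnormalized DFT, sum |x[n]|^2 = (1/N) sum |X[k]|^2. The function is copied from the listing above.

```python
import torch
from torch import Tensor


def complex_modulus(x: Tensor) -> Tensor:
    # Copied from the spectrans source listing above
    if not x.is_complex():
        raise TypeError(f"Input must be complex tensor, got {x.dtype}")
    return torch.abs(x)


torch.manual_seed(0)
n = 32
signal = torch.randn(n, dtype=torch.float64)
spectrum = torch.fft.fft(signal)  # default norm="backward": unnormalized forward DFT

# Per-bin energy is the squared magnitude (a real tensor)
energy_per_bin = complex_modulus(spectrum) ** 2

# Parseval: total energy in time equals total spectral energy divided by N
time_energy = torch.sum(signal ** 2)
freq_energy = torch.sum(energy_per_bin) / n
```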

complex_multiply

complex_multiply(a: Tensor, b: Tensor) -> Tensor

Multiply two complex tensors element-wise.

Computes (a_real + i*a_imag) * (b_real + i*b_imag) element-wise via torch.mul. Supports broadcasting according to PyTorch broadcasting rules.

Parameters:

  • a (Tensor, required): First complex tensor.
  • b (Tensor, required): Second complex tensor.

Returns:

  • Tensor: Complex product tensor.

Raises:

  • TypeError: If inputs are not complex tensors.
  • RuntimeError: If tensors cannot be broadcast together.

Source code in spectrans/utils/complex.py
def complex_multiply(a: Tensor, b: Tensor) -> Tensor:
    """Multiply two complex tensors element-wise.

    Performs (a_real + i*a_imag) * (b_real + i*b_imag) efficiently.
    Supports broadcasting according to PyTorch broadcasting rules.


    Parameters
    ----------
    a : Tensor
        First complex tensor.
    b : Tensor
        Second complex tensor.

    Returns
    -------
    Tensor
        Complex product tensor.

    Raises
    ------
    TypeError
        If inputs are not complex tensors.
    RuntimeError
        If tensors cannot be broadcast together.
    """
    if not a.is_complex():
        raise TypeError(f"First argument must be complex tensor, got {a.dtype}")
    if not b.is_complex():
        raise TypeError(f"Second argument must be complex tensor, got {b.dtype}")

    try:
        return torch.mul(a, b)
    except RuntimeError as e:
        raise RuntimeError(f"Cannot broadcast tensors with shapes {a.shape} and {b.shape}") from e
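Element-wise complex multiplication in the frequency domain is how spectral methods implement circular convolution (the convolution theorem). The sketch copies complex_multiply from the listing above and compares an FFT-based convolution, with one kernel broadcast over a batch, against a direct O(N^2) reference; circular_convolve_direct is a hypothetical helper written only for this check.

```python
import torch
from torch import Tensor


def complex_multiply(a: Tensor, b: Tensor) -> Tensor:
    # Copied from the spectrans source listing above
    if not a.is_complex():
        raise TypeError(f"First argument must be complex tensor, got {a.dtype}")
    if not b.is_complex():
        raise TypeError(f"Second argument must be complex tensor, got {b.dtype}")
    try:
        return torch.mul(a, b)
    except RuntimeError as e:
        raise RuntimeError(f"Cannot broadcast tensors with shapes {a.shape} and {b.shape}") from e


def circular_convolve_direct(x: Tensor, h: Tensor) -> Tensor:
    """Hypothetical reference: y[n] = sum_m x[..., m] * h[(n - m) % N], h is 1D."""
    n = x.shape[-1]
    idx = (torch.arange(n).unsqueeze(1) - torch.arange(n).unsqueeze(0)) % n  # [n, m]
    return (x.unsqueeze(-2) * h[idx]).sum(-1)


torch.manual_seed(0)
x = torch.randn(3, 16, dtype=torch.float64)  # batch of 3 signals
h = torch.randn(16, dtype=torch.float64)     # one kernel, broadcast over the batch

# Convolution theorem: circular convolution == IFFT(FFT(x) * FFT(h))
product = complex_multiply(torch.fft.fft(x), torch.fft.fft(h))
y_fft = torch.fft.ifft(product).real
y_direct = circular_convolve_direct(x, h)
```

The FFT route costs O(N log N) per signal instead of O(N^2), which is the usual motivation for doing this multiplication in the frequency domain.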

complex_phase

complex_phase(x: Tensor) -> Tensor

Compute phase angle of complex tensor.

Phase information is crucial for spectral transformations and filter design.

Parameters:

  • x (Tensor, required): Input complex tensor.

Returns:

  • Tensor: Real tensor containing phase angles in radians [-π, π].

Raises:

  • TypeError: If input is not a complex tensor.

Source code in spectrans/utils/complex.py
def complex_phase(x: Tensor) -> Tensor:
    """Compute phase angle of complex tensor.

    Phase information is crucial for spectral transformations and filter design.

    Parameters
    ----------
    x : Tensor
        Input complex tensor.

    Returns
    -------
    Tensor
        Real tensor containing phase angles in radians [-π, π].

    Raises
    ------
    TypeError
        If input is not a complex tensor.
    """
    if not x.is_complex():
        raise TypeError(f"Input must be complex tensor, got {x.dtype}")

    return torch.angle(x)
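The sketch below (function copied from the listing above) first shows the angle convention on the four axes of the complex plane, then the shift theorem: a circular shift by s samples multiplies X[k] by e^(-2πiks/N), so the phase difference is linear in k. Comparing on the unit circle avoids false mismatches caused by 2π wrap-around.

```python
import math

import torch
from torch import Tensor


def complex_phase(x: Tensor) -> Tensor:
    # Copied from the spectrans source listing above
    if not x.is_complex():
        raise TypeError(f"Input must be complex tensor, got {x.dtype}")
    return torch.angle(x)


# Angle convention on the four axes of the complex plane
axes = torch.tensor([1 + 0j, 0 + 1j, -1 + 0j, 0 - 1j], dtype=torch.complex128)
axis_angles = complex_phase(axes)  # [0, pi/2, pi, -pi/2]

# Shift theorem: rolling by s samples multiplies X[k] by exp(-2*pi*i*k*s/N)
torch.manual_seed(0)
n, s = 16, 3
signal = torch.randn(n, dtype=torch.float64)
phase_orig = complex_phase(torch.fft.fft(signal))
phase_shifted = complex_phase(torch.fft.fft(torch.roll(signal, shifts=s)))
k = torch.arange(n, dtype=torch.float64)

# Compare on the unit circle so that 2*pi wrap-around does not matter
measured = torch.exp(1j * (phase_shifted - phase_orig))
expected = torch.exp(-2j * math.pi * k * s / n)
```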

complex_polar

complex_polar(magnitude: Tensor, phase: Tensor) -> Tensor

Construct complex tensor from magnitude and phase.

Fundamental for spectral operations where separate magnitude and phase processing is required. Includes validation for non-negative magnitudes.

Parameters:

  • magnitude (Tensor, required): Real tensor containing magnitudes (must be non-negative).
  • phase (Tensor, required): Real tensor containing phase angles in radians.

Returns:

  • Tensor: Complex tensor constructed from polar coordinates.

Raises:

  • TypeError: If inputs are not real tensors.
  • ValueError: If magnitude contains negative values.
  • RuntimeError: If tensors cannot be broadcast together.

Source code in spectrans/utils/complex.py
def complex_polar(magnitude: Tensor, phase: Tensor) -> Tensor:
    """Construct complex tensor from magnitude and phase.

    Fundamental for spectral operations where separate magnitude and phase
    processing is required. Includes validation for non-negative magnitudes.

    Parameters
    ----------
    magnitude : Tensor
        Real tensor containing magnitudes (must be non-negative).
    phase : Tensor
        Real tensor containing phase angles in radians.

    Returns
    -------
    Tensor
        Complex tensor constructed from polar coordinates.

    Raises
    ------
    TypeError
        If inputs are not real tensors.
    ValueError
        If magnitude contains negative values.
    RuntimeError
        If tensors cannot be broadcast together.
    """
    if magnitude.is_complex():
        raise TypeError(f"Magnitude must be real tensor, got {magnitude.dtype}")
    if phase.is_complex():
        raise TypeError(f"Phase must be real tensor, got {phase.dtype}")

    if torch.any(magnitude < 0):
        raise ValueError("Magnitude must be non-negative")

    try:
        return torch.polar(magnitude, phase)
    except RuntimeError as e:
        raise RuntimeError(
            f"Cannot broadcast tensors with shapes {magnitude.shape} and {phase.shape}"
        ) from e
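A magnitude/phase round trip, plus an example of processing the two parts separately, using a copy of complex_polar from the listing above. Negative magnitudes are rejected up front; PyTorch's documentation describes torch.polar's behaviour as undefined for a negative abs.

```python
import torch
from torch import Tensor


def complex_polar(magnitude: Tensor, phase: Tensor) -> Tensor:
    # Copied from the spectrans source listing above
    if magnitude.is_complex():
        raise TypeError(f"Magnitude must be real tensor, got {magnitude.dtype}")
    if phase.is_complex():
        raise TypeError(f"Phase must be real tensor, got {phase.dtype}")
    if torch.any(magnitude < 0):
        raise ValueError("Magnitude must be non-negative")
    try:
        return torch.polar(magnitude, phase)
    except RuntimeError as e:
        raise RuntimeError(
            f"Cannot broadcast tensors with shapes {magnitude.shape} and {phase.shape}"
        ) from e


torch.manual_seed(0)
z = torch.randn(4, 8, dtype=torch.complex128)

# Round trip: split into magnitude and phase, then reassemble
magnitude, phase = torch.abs(z), torch.angle(z)
reconstructed = complex_polar(magnitude, phase)

# Separate processing: double every magnitude, keep every phase
amplified = complex_polar(2.0 * magnitude, phase)

# Negative magnitudes are rejected
try:
    complex_polar(-magnitude, phase)
    negative_rejected = False
except ValueError:
    negative_rejected = True
```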

complex_relu

complex_relu(x: Tensor) -> Tensor

Apply ReLU activation to complex tensor.

Applies ReLU to the real and imaginary parts independently: ReLU(Re x) + i*ReLU(Im x). Note: this is not holomorphic, and it changes the phase of any element with a negative component, but it is useful for some neural architectures.

It gives complex-valued layers in spectral transformers a nonlinearity that acts on both components.

Parameters:

  • x (Tensor, required): Input complex tensor.

Returns:

  • Tensor: Complex tensor with ReLU applied to each part.

Raises:

  • TypeError: If input is not a complex tensor.

Source code in spectrans/utils/complex.py
def complex_relu(x: Tensor) -> Tensor:
    """Apply ReLU activation to complex tensor.

    Applies ReLU to both real and imaginary parts independently.
    Note: This is not holomorphic but useful for some neural architectures.

    This specialized activation is designed for complex-valued neural networks
    in spectral transformers where non-linearity is needed in both components.

    Parameters
    ----------
    x : Tensor
        Input complex tensor.

    Returns
    -------
    Tensor
        Complex tensor with ReLU applied to each part.

    Raises
    ------
    TypeError
        If input is not a complex tensor.
    """
    if not x.is_complex():
        raise TypeError(f"Input must be complex tensor, got {x.dtype}")

    real_part = torch.real(x)
    imag_part = torch.imag(x)

    real_relu = torch.relu(real_part)
    imag_relu = torch.relu(imag_part)

    return torch.complex(real_relu, imag_relu)
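Concrete values make the split behaviour visible. This sketch copies complex_relu from the listing above; note how 1-2j is projected onto the real axis, so its phase changes.

```python
import torch
from torch import Tensor


def complex_relu(x: Tensor) -> Tensor:
    # Copied from the spectrans source listing above
    if not x.is_complex():
        raise TypeError(f"Input must be complex tensor, got {x.dtype}")
    real_relu = torch.relu(torch.real(x))
    imag_relu = torch.relu(torch.imag(x))
    return torch.complex(real_relu, imag_relu)


x = torch.tensor([1 - 2j, -3 + 4j, -1 - 1j, 2 + 5j])  # complex64
y = complex_relu(x)  # [1+0j, 0+4j, 0+0j, 2+5j]
```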

make_complex

make_complex(real: Tensor, imag: Tensor) -> Tensor

Construct complex tensor from real and imaginary parts.

Fundamental constructor for complex tensors in spectral transforms.

Parameters:

  • real (Tensor, required): Real part tensor.
  • imag (Tensor, required): Imaginary part tensor.

Returns:

  • Tensor: Complex tensor.

Raises:

  • TypeError: If inputs are not real tensors.
  • RuntimeError: If tensors cannot be broadcast together.

Source code in spectrans/utils/complex.py
def make_complex(real: Tensor, imag: Tensor) -> Tensor:
    """Construct complex tensor from real and imaginary parts.

    Fundamental constructor for complex tensors in spectral transforms.

    Parameters
    ----------
    real : Tensor
        Real part tensor.
    imag : Tensor
        Imaginary part tensor.

    Returns
    -------
    Tensor
        Complex tensor.

    Raises
    ------
    TypeError
        If inputs are not real tensors.
    RuntimeError
        If tensors cannot be broadcast together.
    """
    if real.is_complex():
        raise TypeError(f"Real part must be real tensor, got {real.dtype}")
    if imag.is_complex():
        raise TypeError(f"Imaginary part must be real tensor, got {imag.dtype}")

    try:
        return torch.complex(real, imag)
    except RuntimeError as e:
        raise RuntimeError(
            f"Cannot broadcast tensors with shapes {real.shape} and {imag.shape}"
        ) from e

split_complex

split_complex(x: Tensor) -> tuple[Tensor, Tensor]

Split complex tensor into real and imaginary parts.

Useful for separate processing of real and imaginary components in spectral neural networks and filter implementations.

Parameters:

  • x (Tensor, required): Input complex tensor.

Returns:

  • tuple[Tensor, Tensor]: Tuple of (real_part, imaginary_part) tensors.

Raises:

  • TypeError: If input is not a complex tensor.

Source code in spectrans/utils/complex.py
def split_complex(x: Tensor) -> tuple[Tensor, Tensor]:
    """Split complex tensor into real and imaginary parts.

    Useful for separate processing of real and imaginary components
    in spectral neural networks and filter implementations.

    Parameters
    ----------
    x : Tensor
        Input complex tensor.

    Returns
    -------
    tuple[Tensor, Tensor]
        Tuple of (real_part, imaginary_part) tensors.

    Raises
    ------
    TypeError
        If input is not a complex tensor.
    """
    if not x.is_complex():
        raise TypeError(f"Input must be complex tensor, got {x.dtype}")

    return torch.real(x), torch.imag(x)
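make_complex and split_complex are inverses of each other. The sketch copies both from the listings above. torch.real and torch.imag return views that share storage with their input, so an in-place edit to a component returned by split_complex also changes the original tensor.

```python
import torch
from torch import Tensor


def make_complex(real: Tensor, imag: Tensor) -> Tensor:
    # Copied from the spectrans source listing above
    if real.is_complex():
        raise TypeError(f"Real part must be real tensor, got {real.dtype}")
    if imag.is_complex():
        raise TypeError(f"Imaginary part must be real tensor, got {imag.dtype}")
    try:
        return torch.complex(real, imag)
    except RuntimeError as e:
        raise RuntimeError(
            f"Cannot broadcast tensors with shapes {real.shape} and {imag.shape}"
        ) from e


def split_complex(x: Tensor) -> tuple[Tensor, Tensor]:
    # Copied from the spectrans source listing above
    if not x.is_complex():
        raise TypeError(f"Input must be complex tensor, got {x.dtype}")
    return torch.real(x), torch.imag(x)


torch.manual_seed(0)
real = torch.randn(2, 5)
imag = torch.randn(2, 5)

z = make_complex(real, imag)  # complex64, shape (2, 5)
re, im = split_complex(z)     # views of z's components

# Process the components separately (halve the imaginary part), then recombine
z_scaled = make_complex(re, 0.5 * im)

# The split parts are views: zeroing one in place also changes the source tensor
w = z.clone()
w_re, _ = split_complex(w)
w_re.zero_()
```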

complex_kaiming_init

complex_kaiming_init(tensor: Tensor, gain: float = 1.0, mode: Literal['fan_in', 'fan_out'] = 'fan_in') -> Tensor

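Kaiming initialization for complex tensors. Entries are drawn with complex_normal_init using std = gain / sqrt(fan), where fan_in = tensor.shape[-2], fan_out = tensor.shape[-1], and mode selects which one is used; the resulting second moment is E|w|^2 = gain^2 / fan.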

Parameters:

  • tensor (Tensor, required): Complex tensor to initialize.
  • gain (float, default 1.0): Scaling factor for initialization.
  • mode ({"fan_in", "fan_out"}, default "fan_in"): Fan mode for variance calculation.

Returns:

  • Tensor: Initialized complex tensor.

Raises:

  • TypeError: If tensor is not complex.
  • ValueError: If tensor dimensions or parameters are invalid.

Source code in spectrans/utils/initialization.py
def complex_kaiming_init(
    tensor: Tensor, gain: float = 1.0, mode: Literal["fan_in", "fan_out"] = "fan_in"
) -> Tensor:
    """Kaiming initialization for complex tensors.

    Parameters
    ----------
    tensor : Tensor
        Complex tensor to initialize.
    gain : float, default=1.0
        Scaling factor for initialization.
    mode : {"fan_in", "fan_out"}, default="fan_in"
        Fan mode for variance calculation.

    Returns
    -------
    Tensor
        Initialized complex tensor.

    Raises
    ------
    TypeError
        If tensor is not complex.
    ValueError
        If tensor dimensions or parameters are invalid.
    """
    if not tensor.is_complex():
        raise TypeError(f"Tensor must be complex, got {tensor.dtype}")

    if tensor.ndim < 2:
        raise ValueError(f"Kaiming initialization requires at least 2D tensor, got {tensor.ndim}D")

    if gain <= 0:
        raise ValueError(f"Gain must be positive, got {gain}")

    if mode not in ("fan_in", "fan_out"):
        raise ValueError(f"Mode must be 'fan_in' or 'fan_out', got {mode}")

    # Calculate Kaiming scaling for complex tensors
    fan_in = tensor.shape[-2]
    fan_out = tensor.shape[-1]
    fan = fan_in if mode == "fan_in" else fan_out
    std = gain / math.sqrt(fan)  # Adjusted for complex

    return complex_normal_init(tensor, std)
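A minimal sketch checking the second moment produced by complex_kaiming_init, with both functions condensed from the listings on this page (input validation omitted). Note that this module reads fan_in from tensor.shape[-2] and fan_out from tensor.shape[-1], which matches an (in_features, out_features) layout; nn.Linear stores its weight as (out_features, in_features), so for such a weight the two modes swap meaning.

```python
import math

import torch
from torch import Tensor


def complex_normal_init(tensor: Tensor, std: float = 1.0) -> Tensor:
    # Condensed from the spectrans listing (input validation omitted)
    component_std = std / math.sqrt(2)
    with torch.no_grad():
        real_part = torch.randn_like(tensor.real) * component_std
        imag_part = torch.randn_like(tensor.imag) * component_std
        tensor.copy_(torch.complex(real_part, imag_part))
    return tensor


def complex_kaiming_init(tensor: Tensor, gain: float = 1.0, mode: str = "fan_in") -> Tensor:
    # Condensed from the spectrans listing (input validation omitted)
    fan_in = tensor.shape[-2]
    fan_out = tensor.shape[-1]
    fan = fan_in if mode == "fan_in" else fan_out
    return complex_normal_init(tensor, gain / math.sqrt(fan))


torch.manual_seed(0)
gain = math.sqrt(2.0)

# fan_in mode: shape[-2] = 256, so E|w|^2 should be close to gain^2 / 256
w_in = complex_kaiming_init(torch.empty(256, 512, dtype=torch.complex64), gain=gain)
moment_in = w_in.abs().pow(2).mean().item()
target_in = gain**2 / 256

# fan_out mode: shape[-1] = 512, so E|w|^2 should be close to gain^2 / 512
w_out = complex_kaiming_init(
    torch.empty(256, 512, dtype=torch.complex64), gain=gain, mode="fan_out"
)
moment_out = w_out.abs().pow(2).mean().item()
target_out = gain**2 / 512
```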

complex_normal_init

complex_normal_init(tensor: Tensor, std: float = 1.0) -> Tensor

Initialize complex tensor with complex normal distribution.

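The real and imaginary parts are drawn independently from a normal distribution with standard deviation std / sqrt(2), so the complex entries have total variance E|z|^2 = std^2. The tensor is filled in place under torch.no_grad() and returned.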

Parameters:

  • tensor (Tensor, required): Complex tensor to initialize.
  • std (float, default 1.0): Total standard deviation of the complex entries; each of the real and imaginary parts uses std / sqrt(2).

Returns:

  • Tensor: Initialized complex tensor.

Raises:

  • TypeError: If tensor is not complex.
  • ValueError: If std is not positive.

Source code in spectrans/utils/initialization.py
def complex_normal_init(tensor: Tensor, std: float = 1.0) -> Tensor:
    """Initialize complex tensor with complex normal distribution.

    Both real and imaginary parts are initialized independently with
    normal distribution scaled to maintain proper variance.

    Parameters
    ----------
    tensor : Tensor
        Complex tensor to initialize.
    std : float, default=1.0
        Total standard deviation; each component uses std / sqrt(2).

    Returns
    -------
    Tensor
        Initialized complex tensor.

    Raises
    ------
    TypeError
        If tensor is not complex.
    ValueError
        If std is not positive.
    """
    if not tensor.is_complex():
        raise TypeError(f"Tensor must be complex, got {tensor.dtype}")

    if std <= 0:
        raise ValueError(f"Standard deviation must be positive, got {std}")

    # For complex normal: each component has std/sqrt(2) to maintain total variance
    component_std = std / math.sqrt(2)

    with torch.no_grad():
        # Initialize real and imaginary parts independently
        real_part = torch.randn_like(tensor.real) * component_std
        imag_part = torch.randn_like(tensor.imag) * component_std
        tensor.copy_(torch.complex(real_part, imag_part))

    return tensor
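The variance split can be checked empirically: with std=0.5, the RMS magnitude should be about 0.5 and each component's standard deviation about 0.5 / sqrt(2). The function is copied from the listing above.

```python
import math

import torch
from torch import Tensor


def complex_normal_init(tensor: Tensor, std: float = 1.0) -> Tensor:
    # Copied from the spectrans source listing above
    if not tensor.is_complex():
        raise TypeError(f"Tensor must be complex, got {tensor.dtype}")
    if std <= 0:
        raise ValueError(f"Standard deviation must be positive, got {std}")
    component_std = std / math.sqrt(2)
    with torch.no_grad():
        real_part = torch.randn_like(tensor.real) * component_std
        imag_part = torch.randn_like(tensor.imag) * component_std
        tensor.copy_(torch.complex(real_part, imag_part))
    return tensor


torch.manual_seed(0)
std = 0.5
z = complex_normal_init(torch.empty(100_000, dtype=torch.complex128), std=std)

rms_magnitude = z.abs().pow(2).mean().sqrt().item()  # should be close to std
real_std = z.real.std().item()  # should be close to std / sqrt(2)
imag_std = z.imag.std().item()  # should be close to std / sqrt(2)
```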

complex_xavier_init

complex_xavier_init(tensor: Tensor, gain: float = 1.0) -> Tensor

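Xavier initialization for complex tensors. Entries are drawn with complex_normal_init using std = gain * sqrt(1 / (fan_in + fan_out)), where fan_in = tensor.shape[-2] and fan_out = tensor.shape[-1], so E|w|^2 = gain^2 / (fan_in + fan_out).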

Parameters:

  • tensor (Tensor, required): Complex tensor to initialize.
  • gain (float, default 1.0): Scaling factor for initialization.

Returns:

  • Tensor: Initialized complex tensor.

Raises:

  • TypeError: If tensor is not complex.
  • ValueError: If tensor dimensions or gain are invalid.

Source code in spectrans/utils/initialization.py
def complex_xavier_init(tensor: Tensor, gain: float = 1.0) -> Tensor:
    """Xavier initialization for complex tensors.

    Parameters
    ----------
    tensor : Tensor
        Complex tensor to initialize.
    gain : float, default=1.0
        Scaling factor for initialization.

    Returns
    -------
    Tensor
        Initialized complex tensor.

    Raises
    ------
    TypeError
        If tensor is not complex.
    ValueError
        If tensor dimensions or gain are invalid.
    """
    if not tensor.is_complex():
        raise TypeError(f"Tensor must be complex, got {tensor.dtype}")

    if tensor.ndim < 2:
        raise ValueError(f"Xavier initialization requires at least 2D tensor, got {tensor.ndim}D")

    if gain <= 0:
        raise ValueError(f"Gain must be positive, got {gain}")

    # Calculate Xavier scaling for complex tensors
    fan_in = tensor.shape[-2]
    fan_out = tensor.shape[-1]
    std = gain * math.sqrt(1.0 / (fan_in + fan_out))  # Adjusted for complex

    return complex_normal_init(tensor, std)
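An empirical check of E|w|^2 = gain^2 / (fan_in + fan_out), with the functions condensed from the listings on this page (input validation omitted). For comparison, PyTorch's real-valued torch.nn.init.xavier_normal_ uses variance gain^2 * 2 / (fan_in + fan_out), so at equal gain this complex variant has half that second moment.

```python
import math

import torch
from torch import Tensor


def complex_normal_init(tensor: Tensor, std: float = 1.0) -> Tensor:
    # Condensed from the spectrans listing (input validation omitted)
    component_std = std / math.sqrt(2)
    with torch.no_grad():
        real_part = torch.randn_like(tensor.real) * component_std
        imag_part = torch.randn_like(tensor.imag) * component_std
        tensor.copy_(torch.complex(real_part, imag_part))
    return tensor


def complex_xavier_init(tensor: Tensor, gain: float = 1.0) -> Tensor:
    # Condensed from the spectrans listing (input validation omitted)
    fan_in = tensor.shape[-2]
    fan_out = tensor.shape[-1]
    std = gain * math.sqrt(1.0 / (fan_in + fan_out))
    return complex_normal_init(tensor, std)


torch.manual_seed(0)
fan_in, fan_out = 128, 384
w = complex_xavier_init(torch.empty(fan_in, fan_out, dtype=torch.complex64))

second_moment = w.abs().pow(2).mean().item()  # empirical E|w|^2
target = 1.0 / (fan_in + fan_out)             # gain = 1.0
```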

dct_init

dct_init(tensor: Tensor) -> Tensor

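Fill a 2D tensor of shape (n, m) with the orthonormal DCT-II basis: row 0 is the constant sqrt(1/m), and row i > 0 is sqrt(2/m) * cos(π i (2j + 1) / (2m)) over columns j. When n <= m the rows are orthonormal, and when n == m the matrix is orthogonal.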

Parameters:

  • tensor (Tensor, required): 2D tensor to initialize.

Returns:

  • Tensor: Initialized tensor with DCT-like structure.

Raises:

  • ValueError: If tensor is not 2D.

Source code in spectrans/utils/initialization.py
def dct_init(tensor: Tensor) -> Tensor:
    """Initialize tensor with DCT matrix properties.

    Parameters
    ----------
    tensor : Tensor
        2D tensor to initialize.

    Returns
    -------
    Tensor
        Initialized tensor with DCT-like structure.

    Raises
    ------
    ValueError
        If tensor is not 2D.
    """
    if tensor.ndim != 2:
        raise ValueError(f"DCT initialization requires 2D tensor, got {tensor.ndim}D")

    n, m = tensor.shape

    with torch.no_grad():
        # Build DCT-II matrix
        dct_matrix = torch.zeros(n, m, device=tensor.device, dtype=tensor.dtype)

        for i in range(n):
            for j in range(m):
                if i == 0:
                    dct_matrix[i, j] = math.sqrt(1.0 / m)
                else:
                    dct_matrix[i, j] = math.sqrt(2.0 / m) * math.cos(
                        math.pi * i * (2 * j + 1) / (2 * m)
                    )

        tensor.copy_(dct_matrix)

    return tensor
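The double loop in dct_init can be cross-checked against a vectorized construction. The sketch copies dct_init from the listing above; dct_matrix_vectorized is a hypothetical helper, not part of spectrans. It also confirms orthogonality and that a constant signal maps entirely to the DC coefficient.

```python
import math

import torch
from torch import Tensor


def dct_init(tensor: Tensor) -> Tensor:
    # Copied from the spectrans source listing above
    if tensor.ndim != 2:
        raise ValueError(f"DCT initialization requires 2D tensor, got {tensor.ndim}D")
    n, m = tensor.shape
    with torch.no_grad():
        dct_matrix = torch.zeros(n, m, device=tensor.device, dtype=tensor.dtype)
        for i in range(n):
            for j in range(m):
                if i == 0:
                    dct_matrix[i, j] = math.sqrt(1.0 / m)
                else:
                    dct_matrix[i, j] = math.sqrt(2.0 / m) * math.cos(
                        math.pi * i * (2 * j + 1) / (2 * m)
                    )
        tensor.copy_(dct_matrix)
    return tensor


def dct_matrix_vectorized(n: int, m: int, dtype: torch.dtype = torch.float64) -> Tensor:
    """Hypothetical vectorized equivalent of the double loop in dct_init."""
    i = torch.arange(n, dtype=dtype).unsqueeze(1)  # row (frequency) index
    j = torch.arange(m, dtype=dtype).unsqueeze(0)  # column (sample) index
    matrix = math.sqrt(2.0 / m) * torch.cos(math.pi * i * (2 * j + 1) / (2 * m))
    matrix[0, :] = math.sqrt(1.0 / m)  # DC row uses the smaller normalization
    return matrix


d = dct_init(torch.empty(8, 8, dtype=torch.float64))
matches_loop = torch.allclose(d, dct_matrix_vectorized(8, 8))
orthogonal = torch.allclose(d @ d.T, torch.eye(8, dtype=torch.float64))

# With n < m the rows are still orthonormal
d_wide = dct_init(torch.empty(4, 8, dtype=torch.float64))
rows_orthonormal = torch.allclose(d_wide @ d_wide.T, torch.eye(4, dtype=torch.float64))

# A constant signal puts all of its energy into the DC coefficient
coeffs = d @ torch.ones(8, dtype=torch.float64)
```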

frequency_init

frequency_init(tensor: Tensor, max_freq: float = 1.0) -> Tensor

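Initialize tensor with frequency-domain aware values.

Draws standard normal values and scales them along the last dimension, which is assumed to index frequency bins, by 1 / (1 + f), where f runs linearly from 0 to max_freq. The standard deviation therefore falls from 1 at the lowest bin to 1 / (1 + max_freq) at the highest: larger values at low frequencies and smaller values at high frequencies, loosely mimicking natural signals.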

Parameters:

  • tensor (Tensor, required): Tensor to initialize (typically frequency domain parameters).
  • max_freq (float, default 1.0): Frequency value assigned to the last bin along the last dimension; larger values damp high-frequency bins more strongly.

Returns:

  • Tensor: Initialized tensor.

Raises:

  • ValueError: If max_freq is not positive.

Source code in spectrans/utils/initialization.py
def frequency_init(tensor: Tensor, max_freq: float = 1.0) -> Tensor:
    """Initialize tensor with frequency-domain aware values.

    Initializes with small values at high frequencies and larger values
    at low frequencies, mimicking natural signal characteristics.

    Parameters
    ----------
    tensor : Tensor
        Tensor to initialize (typically frequency domain parameters).
    max_freq : float, default=1.0
        Maximum frequency for scaling.

    Returns
    -------
    Tensor
        Initialized tensor.

    Raises
    ------
    ValueError
        If max_freq is not positive.
    """
    if max_freq <= 0:
        raise ValueError(f"Max frequency must be positive, got {max_freq}")

    with torch.no_grad():
        # Create frequency-based scaling
        # Assume last dimension represents frequency bins
        freq_dim = tensor.shape[-1]
        freqs = torch.linspace(0, max_freq, freq_dim, device=tensor.device)

        # 1/f-like scaling (pink noise characteristic)
        scaling = 1.0 / (1.0 + freqs)

        # Broadcast scaling to tensor shape
        shape = [1] * tensor.ndim
        shape[-1] = freq_dim
        scaling = scaling.view(shape)

        # Initialize with normal then scale
        tensor.normal_(0, 1)
        tensor.mul_(scaling)

    return tensor
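An empirical check of the per-bin scaling, with frequency_init copied from the listing above. The amplitude (standard deviation) falls off as 1/(1 + f), so the expected power falls off as 1/(1 + f)^2; the source comment's "pink noise" label is therefore a loose analogy, since pink noise has power proportional to 1/f.

```python
import torch
from torch import Tensor


def frequency_init(tensor: Tensor, max_freq: float = 1.0) -> Tensor:
    # Copied from the spectrans source listing above
    if max_freq <= 0:
        raise ValueError(f"Max frequency must be positive, got {max_freq}")
    with torch.no_grad():
        freq_dim = tensor.shape[-1]
        freqs = torch.linspace(0, max_freq, freq_dim, device=tensor.device)
        scaling = 1.0 / (1.0 + freqs)
        shape = [1] * tensor.ndim
        shape[-1] = freq_dim
        scaling = scaling.view(shape)
        tensor.normal_(0, 1)
        tensor.mul_(scaling)
    return tensor


torch.manual_seed(0)
max_freq = 3.0
w = frequency_init(torch.empty(20_000, 8), max_freq=max_freq)

# Empirical standard deviation of each frequency bin (column)
per_bin_std = w.std(dim=0)

# Expected 1 / (1 + f) for f = 0, 3/7, ..., 3: from 1.0 down to 0.25
expected_std = 1.0 / (1.0 + torch.linspace(0, max_freq, 8))
```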

hadamard_init

hadamard_init(tensor: Tensor) -> Tensor

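Fill a square tensor whose side n is a power of 2 with the Sylvester-construction Hadamard matrix scaled by 1/sqrt(n). The scaling makes the matrix orthogonal, with every entry equal to ±1/sqrt(n).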

Parameters:

  • tensor (Tensor, required): Square tensor to initialize.

Returns:

  • Tensor: Initialized tensor with Hadamard-like structure.

Raises:

  • ValueError: If tensor is not square or not power-of-2 sized.

Source code in spectrans/utils/initialization.py
def hadamard_init(tensor: Tensor) -> Tensor:
    """Initialize tensor with Hadamard matrix properties.

    Parameters
    ----------
    tensor : Tensor
        Square tensor to initialize.

    Returns
    -------
    Tensor
        Initialized tensor with Hadamard-like structure.

    Raises
    ------
    ValueError
        If tensor is not 2D, not square, or not power-of-2 sized.
    """
    if tensor.ndim != 2:
        raise ValueError(f"Hadamard initialization requires 2D tensor, got {tensor.ndim}D")

    if tensor.shape[0] != tensor.shape[1]:
        raise ValueError(f"Hadamard initialization requires square tensor, got {tensor.shape}")

    size = tensor.shape[0]

    # Check if size is power of 2
    if size & (size - 1) != 0 or size == 0:
        raise ValueError(f"Hadamard initialization requires power-of-2 size, got {size}")

    with torch.no_grad():
        # Build Hadamard matrix recursively
        h = torch.tensor([[1.0]], device=tensor.device, dtype=tensor.dtype)

        while h.shape[0] < size:
            current_size = h.shape[0]
            new_h = torch.zeros(
                2 * current_size, 2 * current_size, device=tensor.device, dtype=tensor.dtype
            )
            new_h[:current_size, :current_size] = h
            new_h[:current_size, current_size:] = h
            new_h[current_size:, :current_size] = h
            new_h[current_size:, current_size:] = -h
            h = new_h

        # Normalize
        h = h / math.sqrt(size)
        tensor.copy_(h)

    return tensor
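The recursion in hadamard_init is the Sylvester construction H_2m = [[H_m, H_m], [H_m, -H_m]]. The plain-Python sketch below (hypothetical helpers `sylvester_hadamard` and `gram`, no PyTorch) builds the same normalized matrix and checks that H @ H.T is the identity, which is why the result preserves norms.

```python
import math


def sylvester_hadamard(size: int) -> list[list[float]]:
    """Normalized Sylvester Hadamard matrix of a power-of-2 size, as nested lists."""
    if size <= 0 or size & (size - 1):
        raise ValueError(f"size must be a positive power of 2, got {size}")
    h = [[1.0]]
    while len(h) < size:
        # [[H, H], [H, -H]] doubles the order at each step
        top = [row + row for row in h]
        bottom = [row + [-v for v in row] for row in h]
        h = top + bottom
    scale = 1.0 / math.sqrt(size)
    return [[v * scale for v in row] for row in h]


def gram(m: list[list[float]]) -> list[list[float]]:
    """Compute M @ M.T for a square nested-list matrix."""
    return [[sum(a * b for a, b in zip(r1, r2)) for r2 in m] for r1 in m]


h4 = sylvester_hadamard(4)
# Diagonal entries are 1 and off-diagonal entries are 0, up to rounding
print([[round(v, 6) for v in row] for row in gram(h4)])
```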

init_conv_spectral

init_conv_spectral(conv: Conv1d | Conv2d, method: str = 'kaiming') -> Conv1d | Conv2d

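Initialize a convolution layer in place: the weight is filled by xavier_spectral_init or kaiming_spectral_init (with the ReLU gain), and the bias, if present, is set to zero. Both initializers read fan-in and fan-out from the last two dimensions of the weight, which for a Conv1d weight of shape (out_channels, in_channels, kernel_size) are in_channels and kernel_size, and for a Conv2d weight are the two kernel dimensions.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| conv | Conv1d \| Conv2d | Convolution layer to initialize. | required |
| method | str | Initialization method: "xavier" or "kaiming". | "kaiming" |

Returns:

| Type | Description |
|---|---|
| Conv1d \| Conv2d | The same layer, initialized in place. |

Raises:

| Type | Description |
|---|---|
| ValueError | If method is not "xavier" or "kaiming". |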

Source code in spectrans/utils/initialization.py
def init_conv_spectral(
    conv: nn.Conv1d | nn.Conv2d, method: str = "kaiming"
) -> nn.Conv1d | nn.Conv2d:
    """Initialize convolution layer with spectral-aware method.

    Parameters
    ----------
    conv : nn.Conv1d | nn.Conv2d
        Convolution layer to initialize.
    method : str, default="kaiming"
        Initialization method: "xavier", "kaiming".

    Returns
    -------
    nn.Conv1d | nn.Conv2d
        Initialized convolution layer.

    Raises
    ------
    ValueError
        If method is not supported.
    """
    if method == "xavier":
        xavier_spectral_init(conv.weight)
    elif method == "kaiming":
        kaiming_spectral_init(conv.weight, nonlinearity="relu")
    else:
        raise ValueError(f"Unsupported method: {method}")

    if conv.bias is not None:
        nn.init.zeros_(conv.bias)

    return conv
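Because init_conv_spectral delegates to initializers that read fans from the last two weight dimensions, its scale can differ from the torch.nn.init convention, which uses fan_in = size(1) × receptive field and fan_out = size(0) × receptive field. The plain-Python comparison below (hypothetical helper names, no PyTorch) shows the resulting ReLU Kaiming std for a Conv1d weight of shape (out_channels, in_channels, kernel_size).

```python
import math


def last_two_dim_fans(shape: tuple[int, ...]) -> tuple[int, int]:
    """Fans as read by xavier/kaiming_spectral_init: (shape[-2], shape[-1])."""
    return shape[-2], shape[-1]


def torch_style_fans(shape: tuple[int, ...]) -> tuple[int, int]:
    """Fans as computed by torch.nn.init for weights with >= 2 dimensions."""
    receptive_field = math.prod(shape[2:])  # 1 for 2D weights
    return shape[1] * receptive_field, shape[0] * receptive_field


def kaiming_relu_std(fan_in: int) -> float:
    """Kaiming normal std for ReLU in fan_in mode: sqrt(2) / sqrt(fan_in)."""
    return math.sqrt(2.0) / math.sqrt(fan_in)


conv1d_weight = (64, 32, 3)  # (out_channels, in_channels, kernel_size)
for name, fans in [("last-two-dims", last_two_dim_fans(conv1d_weight)),
                   ("torch.nn.init", torch_style_fans(conv1d_weight))]:
    print(name, fans, round(kaiming_relu_std(fans[0]), 4))
```

Here the last-two-dimensions convention uses fan_in = 32 while torch.nn.init uses 96, so the std differs by a factor of sqrt(3). For a Conv2d weight of shape (out, in, kh, kw), the last two dimensions are the kernel height and width, so the gap can be larger.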

init_linear_spectral

init_linear_spectral(linear: Linear, method: str = 'xavier') -> Linear

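Initialize a linear layer in place: the weight is filled by xavier_spectral_init, kaiming_spectral_init (default settings), or orthogonal_spectral_init, and the bias, if present, is set to zero.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| linear | Linear | Linear layer to initialize. | required |
| method | str | Initialization method: "xavier", "kaiming", or "orthogonal". | "xavier" |

Returns:

| Type | Description |
|---|---|
| Linear | The same layer, initialized in place. |

Raises:

| Type | Description |
|---|---|
| ValueError | If method is not "xavier", "kaiming", or "orthogonal". |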

Source code in spectrans/utils/initialization.py
def init_linear_spectral(linear: nn.Linear, method: str = "xavier") -> nn.Linear:
    """Initialize linear layer with spectral-aware method.

    Parameters
    ----------
    linear : nn.Linear
        Linear layer to initialize.
    method : str, default="xavier"
        Initialization method: "xavier", "kaiming", "orthogonal".

    Returns
    -------
    nn.Linear
        Initialized linear layer.

    Raises
    ------
    ValueError
        If method is not supported.
    """
    if method == "xavier":
        xavier_spectral_init(linear.weight)
    elif method == "kaiming":
        kaiming_spectral_init(linear.weight)
    elif method == "orthogonal":
        orthogonal_spectral_init(linear.weight)
    else:
        raise ValueError(f"Unsupported method: {method}")

    if linear.bias is not None:
        nn.init.zeros_(linear.bias)

    return linear

kaiming_spectral_init

kaiming_spectral_init(tensor: Tensor, gain: float = 1.0, mode: Literal['fan_in', 'fan_out'] = 'fan_in', nonlinearity: str = 'relu') -> Tensor

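Kaiming/He initialization adapted for spectral transforms.

Fills the tensor in place with samples from a normal distribution with mean 0 and std = gain * g / sqrt(fan), where g is the nonlinearity gain (sqrt(2) for "relu", sqrt(2 / (1 + 0.01^2)) for "leaky_relu", 5/3 for "tanh", and 1 for "linear", "sigmoid", and "gelu"). Choosing mode="fan_in" preserves the variance of activations in the forward pass; mode="fan_out" preserves the variance of gradients in the backward pass. Fan-in is read from tensor.shape[-2] and fan-out from tensor.shape[-1], so the tensor is treated as an (in, out) matrix. An nn.Linear weight is stored as (out_features, in_features), so for such weights mode="fan_in" uses out_features.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| tensor | Tensor | Tensor to initialize, with at least 2 dimensions. | required |
| gain | float | Positive scaling factor applied on top of the nonlinearity gain. | 1.0 |
| mode | Literal["fan_in", "fan_out"] | Fan used for the variance calculation. | "fan_in" |
| nonlinearity | str | Nonlinearity whose gain is applied: "linear", "relu", "leaky_relu", "tanh", "sigmoid", or "gelu". | "relu" |

Returns:

| Type | Description |
|---|---|
| Tensor | The input tensor, initialized in place. |

Raises:

| Type | Description |
|---|---|
| ValueError | If tensor has fewer than 2 dimensions, gain is not positive, mode is invalid, or nonlinearity is unsupported. |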

Source code in spectrans/utils/initialization.py
def kaiming_spectral_init(
    tensor: Tensor,
    gain: float = 1.0,
    mode: Literal["fan_in", "fan_out"] = "fan_in",
    nonlinearity: str = "relu",
) -> Tensor:
    """Kaiming/He initialization adapted for spectral transforms.

    Designed for networks with ReLU-like activations, maintaining
    variance through forward/backward passes.

    Parameters
    ----------
    tensor : Tensor
        Tensor to initialize.
    gain : float, default=1.0
        Scaling factor for initialization.
    mode : {"fan_in", "fan_out"}, default="fan_in"
        Fan mode for variance calculation.
    nonlinearity : str, default="relu"
        Nonlinearity type for gain calculation.

    Returns
    -------
    Tensor
        Initialized tensor.

    Raises
    ------
    ValueError
        If tensor has fewer than 2 dimensions or parameters are invalid.
    """
    if tensor.ndim < 2:
        raise ValueError(f"Kaiming initialization requires at least 2D tensor, got {tensor.ndim}D")

    if gain <= 0:
        raise ValueError(f"Gain must be positive, got {gain}")

    if mode not in ("fan_in", "fan_out"):
        raise ValueError(f"Mode must be 'fan_in' or 'fan_out', got {mode}")

    # Calculate fan-in and fan-out
    fan_in = tensor.shape[-2] if tensor.ndim >= 2 else tensor.numel()
    fan_out = tensor.shape[-1] if tensor.ndim >= 2 else tensor.numel()

    fan = fan_in if mode == "fan_in" else fan_out

    # Nonlinearity-specific gains
    nonlinearity_gains = {
        "linear": 1.0,
        "relu": math.sqrt(2.0),
        "leaky_relu": math.sqrt(2.0 / (1 + 0.01**2)),
        "tanh": 5.0 / 3,
        "sigmoid": 1.0,
        "gelu": 1.0,
    }

    if nonlinearity not in nonlinearity_gains:
        raise ValueError(f"Unsupported nonlinearity: {nonlinearity}")

    nl_gain = nonlinearity_gains[nonlinearity]
    std = gain * nl_gain / math.sqrt(fan)

    with torch.no_grad():
        tensor.normal_(0, std)

    return tensor
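The std that kaiming_spectral_init uses follows directly from its gain table. The plain-Python sketch below (hypothetical helper `kaiming_std`, mirroring the table and fan convention in the source above) also shows that, for an nn.Linear(128, 256) weight stored as (256, 128), mode="fan_in" picks up 256 rather than the input size 128.

```python
import math

# Same nonlinearity gains as in kaiming_spectral_init
NONLINEARITY_GAINS = {
    "linear": 1.0,
    "relu": math.sqrt(2.0),
    "leaky_relu": math.sqrt(2.0 / (1 + 0.01**2)),
    "tanh": 5.0 / 3,
    "sigmoid": 1.0,
    "gelu": 1.0,
}


def kaiming_std(shape: tuple[int, ...], gain: float = 1.0,
                mode: str = "fan_in", nonlinearity: str = "relu") -> float:
    """std = gain * nonlinearity_gain / sqrt(fan), with fans from the last two dims."""
    fan_in, fan_out = shape[-2], shape[-1]
    fan = fan_in if mode == "fan_in" else fan_out
    return gain * NONLINEARITY_GAINS[nonlinearity] / math.sqrt(fan)


linear_weight = (256, 128)  # nn.Linear(128, 256).weight has shape (out, in)
print(round(kaiming_std(linear_weight), 5))                  # fan = 256
print(round(kaiming_std(linear_weight, mode="fan_out"), 5))  # fan = 128
```

Under this convention, passing mode="fan_out" is what makes the fan equal in_features for an nn.Linear weight.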

orthogonal_spectral_init

orthogonal_spectral_init(tensor: Tensor, gain: float = 1.0) -> Tensor

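Orthogonal initialization for spectral transform matrices.

Fills the tensor in place using torch.nn.init.orthogonal_, producing a (semi-)orthogonal matrix scaled by gain. For a square matrix with gain=1, multiplication by the result preserves Euclidean norms, matching the energy conservation of unitary spectral transforms; for a non-square matrix, only the rows or only the columns (whichever are fewer) are orthonormal.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| tensor | Tensor | 2D tensor to initialize. | required |
| gain | float | Positive scaling factor applied to the orthogonal matrix. | 1.0 |

Returns:

| Type | Description |
|---|---|
| Tensor | The input tensor, initialized in place. |

Raises:

| Type | Description |
|---|---|
| ValueError | If tensor is not 2D or gain is not positive. |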

Source code in spectrans/utils/initialization.py
def orthogonal_spectral_init(tensor: Tensor, gain: float = 1.0) -> Tensor:
    """Orthogonal initialization for spectral transform matrices.

    Creates orthogonal matrices that preserve norms, which is important
    for spectral transforms that should maintain energy conservation.

    Parameters
    ----------
    tensor : Tensor
        2D tensor to initialize.
    gain : float, default=1.0
        Scaling factor for the orthogonal matrix.

    Returns
    -------
    Tensor
        Initialized orthogonal tensor.

    Raises
    ------
    ValueError
        If tensor is not 2D or gain is not positive.
    """
    if tensor.ndim != 2:
        raise ValueError(f"Orthogonal initialization requires 2D tensor, got {tensor.ndim}D")

    if gain <= 0:
        raise ValueError(f"Gain must be positive, got {gain}")

    with torch.no_grad():
        nn.init.orthogonal_(tensor, gain=gain)

    return tensor
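torch.nn.init.orthogonal_ builds its matrix from a QR decomposition of a Gaussian random matrix. The dependency-free sketch below instead uses modified Gram-Schmidt on the columns of a Gaussian matrix (equivalent to QR for a full-rank matrix) and checks the norm-preservation property described above; `random_orthogonal`, `matvec`, and `l2norm` are hypothetical helpers, not part of spectrans.

```python
import math
import random


def random_orthogonal(n: int, gain: float = 1.0, seed: int = 0) -> list[list[float]]:
    """Random n x n orthogonal matrix, scaled by gain, via modified Gram-Schmidt."""
    rng = random.Random(seed)
    # Columns of a Gaussian random matrix (linearly independent with probability 1)
    columns = [[rng.gauss(0.0, 1.0) for _ in range(n)] for _ in range(n)]
    basis: list[list[float]] = []
    for v in columns:
        for q in basis:
            # Subtract the projection of v onto each earlier basis vector
            dot = sum(a * b for a, b in zip(v, q))
            v = [a - dot * b for a, b in zip(v, q)]
        length = math.sqrt(sum(a * a for a in v))
        basis.append([a / length for a in v])
    # basis[j] is column j; transpose into row-major order and apply gain
    return [[gain * basis[j][i] for j in range(n)] for i in range(n)]


def matvec(m: list[list[float]], x: list[float]) -> list[float]:
    """Matrix-vector product for a row-major nested-list matrix."""
    return [sum(a * b for a, b in zip(row, x)) for row in m]


def l2norm(v: list[float]) -> float:
    """Euclidean norm of a vector."""
    return math.sqrt(sum(a * a for a in v))


q = random_orthogonal(4)
x = [1.0, -2.0, 0.5, 3.0]
print(l2norm(x), l2norm(matvec(q, x)))  # equal up to rounding, since ||Qx|| = ||x||
```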

spectral_init

spectral_init(tensor: Tensor, mode: str = 'normal', gain: float = 1.0) -> Tensor

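Initialize a tensor in place using the scheme selected by mode: "normal" draws from a normal distribution with mean 0 and std gain, "uniform" draws from U(-gain, gain), and "xavier", "kaiming", and "orthogonal" delegate to xavier_spectral_init, kaiming_spectral_init (with its default ReLU gain), and orthogonal_spectral_init, passing gain through.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| tensor | Tensor | Tensor to initialize. | required |
| mode | str | Initialization mode: "normal", "uniform", "xavier", "kaiming", or "orthogonal". | "normal" |
| gain | float | Positive scaling factor; for "normal" it is the std, for "uniform" the bound. | 1.0 |

Returns:

| Type | Description |
|---|---|
| Tensor | The input tensor, initialized in place. |

Raises:

| Type | Description |
|---|---|
| ValueError | If mode is not supported, gain is not positive, or the tensor shape is invalid for the chosen mode (fewer than 2 dimensions for "xavier" and "kaiming", not 2D for "orthogonal"). |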

Source code in spectrans/utils/initialization.py
def spectral_init(tensor: Tensor, mode: str = "normal", gain: float = 1.0) -> Tensor:
    """Initialize tensor with spectral-aware method.

    Parameters
    ----------
    tensor : Tensor
        Tensor to initialize.
    mode : str, default="normal"
        Initialization mode: "normal", "uniform", "xavier", "kaiming", "orthogonal".
    gain : float, default=1.0
        Scaling factor for initialization.

    Returns
    -------
    Tensor
        Initialized tensor.

    Raises
    ------
    ValueError
        If mode is not supported, gain is not positive, or the tensor shape is
        invalid for the chosen mode (e.g. not 2D for orthogonal initialization).
    """
    if gain <= 0:
        raise ValueError(f"Gain must be positive, got {gain}")

    with torch.no_grad():
        if mode == "normal":
            # Standard normal initialization scaled by gain
            tensor.normal_(0, gain)
        elif mode == "uniform":
            # Uniform initialization in [-gain, gain]
            tensor.uniform_(-gain, gain)
        elif mode == "xavier":
            xavier_spectral_init(tensor, gain=gain)
        elif mode == "kaiming":
            kaiming_spectral_init(tensor, gain=gain)
        elif mode == "orthogonal":
            orthogonal_spectral_init(tensor, gain=gain)
        else:
            raise ValueError(f"Unsupported initialization mode: {mode}")

    return tensor
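In spectral_init, gain means different things for the two sampling modes: for "normal" it is the standard deviation, while for "uniform" it is the half-width of U(-gain, gain), whose standard deviation is gain / sqrt(3). The plain-Python sketch below (hypothetical helper `sampling_std`) computes both and checks the uniform case empirically.

```python
import math
import random


def sampling_std(mode: str, gain: float = 1.0) -> float:
    """Theoretical std of spectral_init's "normal" and "uniform" modes."""
    if mode == "normal":
        return gain  # tensor.normal_(0, gain)
    if mode == "uniform":
        return gain / math.sqrt(3.0)  # Var(U(-a, a)) = a**2 / 3
    raise ValueError(f"only 'normal' and 'uniform' are covered here, got {mode}")


rng = random.Random(0)
samples = [rng.uniform(-2.0, 2.0) for _ in range(100_000)]
mean = sum(samples) / len(samples)
empirical = math.sqrt(sum((s - mean) ** 2 for s in samples) / len(samples))
print(round(sampling_std("uniform", 2.0), 4), round(empirical, 4))
```

So with the same gain, "uniform" yields roughly 58% of the spread of "normal".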

wavelet_init

wavelet_init(tensor: Tensor, wavelet_type: str = 'db1') -> Tensor

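Initialize a tensor in place with a simple wavelet-inspired sign pattern. Entries are drawn from a normal distribution with mean 0 and std 0.1; for tensors with at least 2 dimensions, every odd-indexed slice along the last dimension is then negated. The "db1", "haar", and "db2" options currently apply the same alternating +1/-1 pattern.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| tensor | Tensor | Tensor to initialize. | required |
| wavelet_type | str | Wavelet pattern: "db1", "db2", or "haar". | "db1" |

Returns:

| Type | Description |
|---|---|
| Tensor | The input tensor, initialized in place. |

Raises:

| Type | Description |
|---|---|
| ValueError | If wavelet_type is not "db1", "db2", or "haar". |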

Source code in spectrans/utils/initialization.py
def wavelet_init(tensor: Tensor, wavelet_type: str = "db1") -> Tensor:
    """Initialize tensor with wavelet-like properties.

    Parameters
    ----------
    tensor : Tensor
        Tensor to initialize.
    wavelet_type : str, default="db1"
        Type of wavelet initialization.

    Returns
    -------
    Tensor
        Initialized tensor.

    Raises
    ------
    ValueError
        If wavelet_type is not supported.
    """
    supported_wavelets = ["db1", "db2", "haar"]
    if wavelet_type not in supported_wavelets:
        raise ValueError(f"Wavelet type must be one of {supported_wavelets}, got {wavelet_type}")

    with torch.no_grad():
        if wavelet_type in ("db1", "haar"):
            # Haar/Daubechies-1 wavelet properties
            # Initialize with small random values then apply haar-like structure
            tensor.normal_(0, 0.1)

            # Apply alternating signs for wavelet-like behavior
            if tensor.ndim >= 2:
                for i in range(tensor.shape[-1]):
                    if i % 2 == 1:
                        tensor[..., i] *= -1
        elif wavelet_type == "db2":
            # Daubechies-2 initialization
            tensor.normal_(0, 0.1)
            # Apply more complex pattern for DB2
            if tensor.ndim >= 2:
                pattern = [1, -1, 1, -1]  # Simple DB2-like pattern
                for i in range(tensor.shape[-1]):
                    tensor[..., i] *= pattern[i % len(pattern)]

    return tensor
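The sign patterns wavelet_init applies along the last dimension can be written out explicitly. The plain-Python sketch below (hypothetical helper `apply_sign_pattern`) reproduces them on a small nested list of ones, and shows that the "db2" pattern [1, -1, 1, -1] yields the same signs as the alternating pattern used for "db1" and "haar".

```python
def apply_sign_pattern(rows: list[list[float]], pattern: list[int]) -> list[list[float]]:
    """Multiply column i of every row by pattern[i % len(pattern)]."""
    return [[v * pattern[i % len(pattern)] for i, v in enumerate(row)] for row in rows]


ones = [[1.0] * 6 for _ in range(2)]
haar_like = apply_sign_pattern(ones, [1, -1])        # "db1" / "haar": negate odd columns
db2_like = apply_sign_pattern(ones, [1, -1, 1, -1])  # "db2" pattern from the source
print(haar_like[0])           # [1.0, -1.0, 1.0, -1.0, 1.0, -1.0]
print(haar_like == db2_like)  # True
```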

xavier_spectral_init

xavier_spectral_init(tensor: Tensor, gain: float = 1.0, distribution: Literal['normal', 'uniform'] = 'normal') -> Tensor

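Xavier/Glorot initialization adapted for spectral transforms.

Keeps the variance of activations and gradients roughly constant across layers by scaling with fan_in + fan_out. The normal distribution uses std = gain * sqrt(2 / (fan_in + fan_out)); the uniform distribution samples from U(-b, b) with b = gain * sqrt(6 / (fan_in + fan_out)), which has the same variance. Fan-in is read from tensor.shape[-2] and fan-out from tensor.shape[-1]; this differs from torch.nn.init, which for tensors with more than two dimensions (such as convolution weights) multiplies size(1) and size(0) by the receptive-field size.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| tensor | Tensor | Tensor to initialize, with at least 2 dimensions. | required |
| gain | float | Positive scaling factor. | 1.0 |
| distribution | Literal["normal", "uniform"] | Distribution to sample from. | "normal" |

Returns:

| Type | Description |
|---|---|
| Tensor | The input tensor, initialized in place. |

Raises:

| Type | Description |
|---|---|
| ValueError | If tensor has fewer than 2 dimensions, gain is not positive, or distribution is invalid. |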

Source code in spectrans/utils/initialization.py
def xavier_spectral_init(
    tensor: Tensor, gain: float = 1.0, distribution: Literal["normal", "uniform"] = "normal"
) -> Tensor:
    """Xavier/Glorot initialization adapted for spectral transforms.

    Maintains variance of activations and gradients across layers by scaling
    based on input and output dimensions.

    Parameters
    ----------
    tensor : Tensor
        Tensor to initialize.
    gain : float, default=1.0
        Scaling factor for initialization.
    distribution : {"normal", "uniform"}, default="normal"
        Distribution to use for initialization.

    Returns
    -------
    Tensor
        Initialized tensor.

    Raises
    ------
    ValueError
        If tensor has fewer than 2 dimensions, gain is not positive,
        or distribution is invalid.
    """
    if tensor.ndim < 2:
        raise ValueError(f"Xavier initialization requires at least 2D tensor, got {tensor.ndim}D")

    if gain <= 0:
        raise ValueError(f"Gain must be positive, got {gain}")

    if distribution not in ("normal", "uniform"):
        raise ValueError(f"Distribution must be 'normal' or 'uniform', got {distribution}")

    # Calculate fan-in and fan-out from the last two dimensions only:
    # shape[-2] is treated as the input size and shape[-1] as the output size
    # (the tensor is viewed as an (in, out) matrix; leading dimensions are ignored)
    fan_in = tensor.shape[-2] if tensor.ndim >= 2 else tensor.numel()
    fan_out = tensor.shape[-1] if tensor.ndim >= 2 else tensor.numel()

    # Xavier scaling factor
    std = gain * math.sqrt(2.0 / (fan_in + fan_out))

    with torch.no_grad():
        if distribution == "normal":
            tensor.normal_(0, std)
        else:  # uniform
            bound = gain * math.sqrt(6.0 / (fan_in + fan_out))
            tensor.uniform_(-bound, bound)

    return tensor
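The two distributions in xavier_spectral_init are chosen to have the same variance: U(-b, b) has variance b^2 / 3, and b = gain * sqrt(6 / (fan_in + fan_out)) gives b^2 / 3 = gain^2 * 2 / (fan_in + fan_out). The plain-Python sketch below (hypothetical helper `xavier_scales`) checks this for a (512, 256) matrix.

```python
import math


def xavier_scales(fan_in: int, fan_out: int, gain: float = 1.0) -> tuple[float, float]:
    """Return (normal std, uniform bound) as computed in xavier_spectral_init."""
    std = gain * math.sqrt(2.0 / (fan_in + fan_out))
    bound = gain * math.sqrt(6.0 / (fan_in + fan_out))
    return std, bound


std, bound = xavier_scales(512, 256)
# The std of U(-bound, bound) is bound / sqrt(3), which matches the normal std
print(round(std, 6), round(bound / math.sqrt(3.0), 6))
```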

circular_pad

circular_pad(x: Tensor, pad_amount: int, dim: int = -1) -> Tensor

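Apply circular (periodic) padding by appending pad_amount elements to the end of dim.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Input tensor. | required |
| pad_amount | int | Number of elements to append; must not exceed the size of dim. | required |
| dim | int | Dimension to pad along. | -1 |

Returns:

| Type | Description |
|---|---|
| Tensor | Circularly padded tensor. |

Raises:

| Type | Description |
|---|---|
| ValueError | If pad_amount is negative or exceeds the size of dim. |
| IndexError | If dim is out of bounds. |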

Source code in spectrans/utils/padding.py
def circular_pad(x: Tensor, pad_amount: int, dim: int = -1) -> Tensor:
    """Apply circular (periodic) padding.

    Parameters
    ----------
    x : Tensor
        Input tensor.
    pad_amount : int
        Number of elements to pad.
    dim : int, default=-1
        Dimension to pad along.

    Returns
    -------
    Tensor
        Circularly padded tensor.

    Raises
    ------
    ValueError
        If pad_amount is negative or exceeds tensor size.
    IndexError
        If dimension is out of bounds.
    """
    if pad_amount < 0:
        raise ValueError(f"Pad amount must be non-negative, got {pad_amount}")

    if pad_amount == 0:
        return x

    if dim >= x.ndim or dim < -x.ndim:
        raise IndexError(f"Dimension {dim} out of bounds for tensor with {x.ndim} dimensions")

    dim = dim % x.ndim
    seq_len = x.shape[dim]

    if pad_amount > seq_len:
        raise ValueError(f"Circular pad amount {pad_amount} exceeds tensor size {seq_len}")

    # Periodic extension (x[n + N] == x[n]): append the first pad_amount elements
    slices = [slice(None)] * x.ndim
    slices[dim] = slice(0, pad_amount)
    padding = x[tuple(slices)]

    return torch.cat([x, padding], dim=dim)

pad_for_convolution

pad_for_convolution(x: Tensor, kernel_size: int, dim: int = -1, mode: str = 'zero') -> Tensor

Pad a tensor so that a subsequent unpadded ("valid") convolution with an odd kernel_size produces an output of the same length as the input ("same" padding).

Adds kernel_size // 2 elements to each side of the specified dimension, taking the values according to mode.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Input tensor. | required |
| kernel_size | int | Size of the convolution kernel; must be a positive odd integer. | required |
| dim | int | Dimension to pad along. | -1 |
| mode | str | Padding mode: "zero", "circular", "reflect", or "symmetric". | "zero" |

Returns:

| Type | Description |
|---|---|
| Tensor | Tensor whose size along dim is increased by 2 * (kernel_size // 2). |

Raises:

| Type | Description |
|---|---|
| ValueError | If kernel_size is not a positive odd integer, mode is invalid, or the tensor is too short for "reflect" or "symmetric" padding. |
| IndexError | If dim is out of bounds. |

Source code in spectrans/utils/padding.py
def pad_for_convolution(x: Tensor, kernel_size: int, dim: int = -1, mode: str = "zero") -> Tensor:
    """Pad tensor for valid convolution without size reduction.

    Adds kernel_size // 2 elements to both sides of the specified dimension so
    that a valid (unpadded) convolution output has the same size as the input.

    Parameters
    ----------
    x : Tensor
        Input tensor.
    kernel_size : int
        Size of convolution kernel.
    dim : int, default=-1
        Dimension to pad along.
    mode : str, default="zero"
        Padding mode: "zero", "circular", "reflect", "symmetric".

    Returns
    -------
    Tensor
        Padded tensor suitable for convolution.

    Raises
    ------
    ValueError
        If kernel_size is not positive odd integer or mode is invalid.
    IndexError
        If dimension is out of bounds.
    """
    if kernel_size <= 0 or kernel_size % 2 == 0:
        raise ValueError(f"Kernel size must be positive odd integer, got {kernel_size}")

    if dim >= x.ndim or dim < -x.ndim:
        raise IndexError(f"Dimension {dim} out of bounds for tensor with {x.ndim} dimensions")

    # Normalize negative dimension
    dim = dim % x.ndim

    # Calculate padding needed for 'same' convolution
    pad_amount = kernel_size // 2

    if pad_amount == 0:
        return x

    # Apply symmetric padding based on mode
    if mode == "zero":
        # Create left padding
        left_pad_shape = list(x.shape)
        left_pad_shape[dim] = pad_amount
        left_padding = torch.zeros(left_pad_shape, dtype=x.dtype, device=x.device)

        # Create right padding
        right_pad_shape = list(x.shape)
        right_pad_shape[dim] = pad_amount
        right_padding = torch.zeros(right_pad_shape, dtype=x.dtype, device=x.device)

        # Concatenate: left_padding + original + right_padding
        return torch.cat([left_padding, x, right_padding], dim=dim)

    elif mode == "circular":
        # Apply circular padding on both sides
        # Left padding: take last pad_amount elements
        left_slices = [slice(None)] * x.ndim
        left_slices[dim] = slice(-pad_amount, None)
        left_padding = x[tuple(left_slices)]

        # Right padding: take first pad_amount elements
        right_slices = [slice(None)] * x.ndim
        right_slices[dim] = slice(0, pad_amount)
        right_padding = x[tuple(right_slices)]

        return torch.cat([left_padding, x, right_padding], dim=dim)

    elif mode == "reflect":
        # Apply reflection padding on both sides
        seq_len = x.shape[dim]
        if 2 * pad_amount >= seq_len:
            raise ValueError(
                f"Reflect pad amount {2 * pad_amount} too large for tensor size {seq_len}"
            )

        # Left padding: reflect first pad_amount elements (excluding edges)
        left_slices = [slice(None)] * x.ndim
        left_slices[dim] = slice(1, pad_amount + 1)
        left_padding = torch.flip(x[tuple(left_slices)], dims=[dim])

        # Right padding: reflect last pad_amount elements (excluding edges)
        right_slices = [slice(None)] * x.ndim
        right_slices[dim] = slice(seq_len - pad_amount - 1, seq_len - 1)
        right_padding = torch.flip(x[tuple(right_slices)], dims=[dim])

        return torch.cat([left_padding, x, right_padding], dim=dim)

    elif mode == "symmetric":
        # Apply symmetric padding on both sides
        seq_len = x.shape[dim]
        if 2 * pad_amount > seq_len:
            raise ValueError(
                f"Symmetric pad amount {2 * pad_amount} too large for tensor size {seq_len}"
            )

        # Left padding: reflect first pad_amount elements (including edges)
        left_slices = [slice(None)] * x.ndim
        left_slices[dim] = slice(0, pad_amount)
        left_padding = torch.flip(x[tuple(left_slices)], dims=[dim])

        # Right padding: reflect last pad_amount elements (including edges)
        right_slices = [slice(None)] * x.ndim
        right_slices[dim] = slice(seq_len - pad_amount, seq_len)
        right_padding = torch.flip(x[tuple(right_slices)], dims=[dim])

        return torch.cat([left_padding, x, right_padding], dim=dim)

    else:
        raise ValueError(
            f"Invalid padding mode: {mode}. Must be one of: zero, circular, reflect, symmetric"
        )
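The four modes of pad_for_convolution differ only in where the kernel_size // 2 boundary values on each side come from. The plain-Python sketch below (hypothetical helper `pad_both_sides`, 1D lists only) mirrors the slicing used in the source so the modes can be compared side by side.

```python
def pad_both_sides(x: list[float], kernel_size: int, mode: str = "zero") -> list[float]:
    """Pad a 1D list by kernel_size // 2 on each side, like pad_for_convolution."""
    if kernel_size <= 0 or kernel_size % 2 == 0:
        raise ValueError(f"Kernel size must be positive odd integer, got {kernel_size}")
    p = kernel_size // 2
    n = len(x)
    if p == 0:
        return list(x)
    if mode == "zero":
        left, right = [0] * p, [0] * p
    elif mode == "circular":
        left, right = x[-p:], x[:p]  # wrap around the ends
    elif mode == "reflect":
        if 2 * p >= n:
            raise ValueError("reflect padding too large for input")
        left, right = x[1:p + 1][::-1], x[n - p - 1:n - 1][::-1]  # edge not repeated
    elif mode == "symmetric":
        if 2 * p > n:
            raise ValueError("symmetric padding too large for input")
        left, right = x[:p][::-1], x[n - p:][::-1]  # edge repeated
    else:
        raise ValueError(f"Invalid padding mode: {mode}")
    return left + list(x) + right


signal = [1, 2, 3, 4, 5]
for mode in ("zero", "circular", "reflect", "symmetric"):
    print(f"{mode:>9}: {pad_both_sides(signal, kernel_size=5, mode=mode)}")
```

With kernel_size=5 the results are [0, 0, 1, 2, 3, 4, 5, 0, 0] (zero), [4, 5, 1, 2, 3, 4, 5, 1, 2] (circular), [3, 2, 1, 2, 3, 4, 5, 4, 3] (reflect), and [2, 1, 1, 2, 3, 4, 5, 5, 4] (symmetric); a valid convolution with a length-5 kernel over any of them returns 5 outputs.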

pad_for_fft

pad_for_fft(x: Tensor, dim: int = -1) -> tuple[Tensor, int]

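Pad a tensor with zeros along dim up to the next power of 2.

Power-of-2 lengths are typically the fastest sizes for FFT implementations. The original length is returned with the padded tensor so that results can be cropped back afterwards; if the length is already a power of 2, the tensor is returned unchanged.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Input tensor. | required |
| dim | int | Dimension to pad along. | -1 |

Returns:

| Type | Description |
|---|---|
| tuple[Tensor, int] | Tuple of (padded_tensor, original_length). |

Raises:

| Type | Description |
|---|---|
| IndexError | If dim is out of bounds. |
| ValueError | If x has size 0 along dim. |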

Source code in spectrans/utils/padding.py
def pad_for_fft(x: Tensor, dim: int = -1) -> tuple[Tensor, int]:
    """Pad tensor to optimal size for FFT computation.

    Pads to next power of 2 for optimal FFT performance.

    Parameters
    ----------
    x : Tensor
        Input tensor.
    dim : int, default=-1
        Dimension to pad along.

    Returns
    -------
    tuple[Tensor, int]
        Tuple of (padded_tensor, original_length).

    Raises
    ------
    IndexError
        If dimension is out of bounds.
    ValueError
        If tensor is empty along specified dimension.
    """
    if dim >= x.ndim or dim < -x.ndim:
        raise IndexError(f"Dimension {dim} out of bounds for tensor with {x.ndim} dimensions")

    dim = dim % x.ndim
    original_length = x.shape[dim]

    if original_length == 0:
        raise ValueError("Cannot pad empty tensor dimension for FFT")

    # Find next power of 2
    optimal_length = pad_to_power_of_2(original_length)

    if optimal_length == original_length:
        return x, original_length

    padded = pad_to_length(x, optimal_length, dim, mode="zero")
    return padded, original_length
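A typical use of pad_for_fft is to pad, process, and then crop back using the returned original length. The plain-Python sketch below (hypothetical helpers `next_power_of_2` and `pad_list_for_fft`, with zeros appended at the end as an assumption) shows that round trip on a list.

```python
def next_power_of_2(length: int) -> int:
    """Smallest power of 2 >= length, using the doubling loop from pad_to_power_of_2."""
    if length <= 0:
        raise ValueError(f"Length must be positive, got {length}")
    power = 1
    while power < length:
        power <<= 1
    return power


def pad_list_for_fft(x: list[float]) -> tuple[list[float], int]:
    """Zero-pad x at the end to a power-of-2 length; also return the original length."""
    if not x:
        raise ValueError("Cannot pad empty sequence for FFT")
    target = next_power_of_2(len(x))
    return x + [0.0] * (target - len(x)), len(x)


padded, original_length = pad_list_for_fft([1.0, 2.0, 3.0, 4.0, 5.0])
print(len(padded), original_length)  # 8 5
restored = padded[:original_length]  # crop back after processing
```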

pad_sequence

pad_sequence(sequences: list[Tensor], padding_value: float = 0.0, dim: int = -1) -> Tensor

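Pad a list of tensors along dim to the length of the longest one, filling with padding_value, and stack them along a new leading batch dimension.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| sequences | list[Tensor] | Tensors to pad; all must have the same shape except along dim. | required |
| padding_value | float | Value used for padding. | 0.0 |
| dim | int | Dimension of the input tensors to pad along. | -1 |

Returns:

| Type | Description |
|---|---|
| Tensor | Batched tensor of shape (len(sequences), ...), with every sequence padded to the maximum length along dim. |

Raises:

| Type | Description |
|---|---|
| ValueError | If sequences is empty or the tensors have incompatible shapes. |
| IndexError | If dim is out of bounds. |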

Source code in spectrans/utils/padding.py
def pad_sequence(sequences: list[Tensor], padding_value: float = 0.0, dim: int = -1) -> Tensor:
    """Pad a list of sequences to the same length.

    Parameters
    ----------
    sequences : list[Tensor]
        List of tensors to pad.
    padding_value : float, default=0.0
        Value to use for padding.
    dim : int, default=-1
        Dimension to pad along.

    Returns
    -------
    Tensor
        Batched tensor with all sequences padded to same length.

    Raises
    ------
    ValueError
        If sequences is empty or tensors have incompatible shapes.
    IndexError
        If dimension is out of bounds.
    """
    if not sequences:
        raise ValueError("Cannot pad empty list of sequences")

    if len(sequences) == 1:
        return sequences[0].unsqueeze(0)

    # Check dimension bounds
    ndim = sequences[0].ndim
    if dim >= ndim or dim < -ndim:
        raise IndexError(f"Dimension {dim} out of bounds for tensor with {ndim} dimensions")

    dim = dim % ndim

    # Verify all tensors have compatible shapes (same except for padding dimension)
    ref_shape = list(sequences[0].shape)
    for i, seq in enumerate(sequences[1:], 1):
        if seq.ndim != ndim:
            raise ValueError(
                f"All sequences must have same number of dimensions. "
                f"Sequence 0 has {ndim}, sequence {i} has {seq.ndim}"
            )

        seq_shape = list(seq.shape)
        for d in range(ndim):
            if d != dim and seq_shape[d] != ref_shape[d]:
                raise ValueError(
                    f"All sequences must have same shape except in padding dimension. "
                    f"Mismatch in dimension {d}: {ref_shape[d]} vs {seq_shape[d]}"
                )

    # Find maximum length
    max_length = max(seq.shape[dim] for seq in sequences)

    # Pad each sequence
    padded_sequences = []
    for seq in sequences:
        if seq.shape[dim] < max_length:
            pad_amount = max_length - seq.shape[dim]
            padded_seq = zero_pad(seq, pad_amount, dim, padding_value)
        else:
            padded_seq = seq
        padded_sequences.append(padded_seq)

    # Stack into batch
    return torch.stack(padded_sequences, dim=0)
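pad_sequence and unpad_sequence are meant to be used as a pair, with the original lengths recorded before padding. The plain-Python sketch below (hypothetical helpers `pad_lists` and `unpad_lists`, 1D lists, padding appended at the end as an assumption) shows that round trip.

```python
def pad_lists(seqs: list[list[float]],
              padding_value: float = 0.0) -> tuple[list[list[float]], list[int]]:
    """Pad 1D lists to the longest length; return the batch and the original lengths."""
    if not seqs:
        raise ValueError("Cannot pad empty list of sequences")
    lengths = [len(s) for s in seqs]
    max_length = max(lengths)
    batch = [s + [padding_value] * (max_length - len(s)) for s in seqs]
    return batch, lengths


def unpad_lists(batch: list[list[float]], lengths: list[int]) -> list[list[float]]:
    """Crop each padded row back to its recorded length."""
    if len(batch) != len(lengths):
        raise ValueError("lengths must match batch size")
    return [row[:n] for row, n in zip(batch, lengths)]


seqs = [[1.0, 2.0, 3.0], [4.0], [5.0, 6.0]]
batch, lengths = pad_lists(seqs, padding_value=-1.0)
print(batch)  # [[1.0, 2.0, 3.0], [4.0, -1.0, -1.0], [5.0, 6.0, -1.0]]
assert unpad_lists(batch, lengths) == seqs
```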

pad_to_length

pad_to_length(x: Tensor, target_length: int, dim: int = -1, mode: str = 'zero') -> Tensor

Pad a tensor along dim until its size there equals target_length, using the selected padding mode.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Input tensor to pad. | required |
| target_length | int | Length along dim after padding. | required |
| dim | int | Dimension to pad along. | -1 |
| mode | str | Padding mode: "zero", "circular", "reflect", or "symmetric". | "zero" |

Returns:

| Type | Description |
|---|---|
| Tensor | Padded tensor with target_length elements along dim. |

Raises:

| Type | Description |
|---|---|
| ValueError | If target_length is smaller than the current length, mode is invalid, or the required pad amount is too large for "circular", "reflect", or "symmetric" padding. |
| IndexError | If dim is out of bounds. |

Source code in spectrans/utils/padding.py
def pad_to_length(x: Tensor, target_length: int, dim: int = -1, mode: str = "zero") -> Tensor:
    """Pad tensor to specified length along given dimension.

    Parameters
    ----------
    x : Tensor
        Input tensor to pad.
    target_length : int
        Target length after padding.
    dim : int, default=-1
        Dimension to pad along.
    mode : str, default="zero"
        Padding mode: "zero", "circular", "reflect", "symmetric".

    Returns
    -------
    Tensor
        Padded tensor with target length along specified dimension.

    Raises
    ------
    ValueError
        If target_length is smaller than current length, or invalid mode.
    IndexError
        If dimension is out of bounds.
    """
    if dim >= x.ndim or dim < -x.ndim:
        raise IndexError(f"Dimension {dim} out of bounds for tensor with {x.ndim} dimensions")

    # Normalize negative dimension
    dim = dim % x.ndim
    current_length = x.shape[dim]

    if target_length < current_length:
        raise ValueError(
            f"Target length {target_length} must be >= current length {current_length}"
        )

    if target_length == current_length:
        return x

    pad_amount = target_length - current_length

    # Convert mode to appropriate function
    if mode == "zero":
        return zero_pad(x, pad_amount, dim)
    elif mode == "circular":
        return circular_pad(x, pad_amount, dim)
    elif mode == "reflect":
        return reflect_pad(x, pad_amount, dim)
    elif mode == "symmetric":
        return symmetric_pad(x, pad_amount, dim)
    else:
        raise ValueError(
            f"Invalid padding mode: {mode}. Must be one of: zero, circular, reflect, symmetric"
        )

pad_to_power_of_2

pad_to_power_of_2(length: int) -> int

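Return the smallest power of 2 that is greater than or equal to length. Despite its name, this function computes a length and does not pad a tensor; pad_for_fft performs the actual padding.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| length | int | Positive input length. | required |

Returns:

| Type | Description |
|---|---|
| int | Smallest power of 2 that is >= length. |

Raises:

| Type | Description |
|---|---|
| ValueError | If length is not positive. |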

Source code in spectrans/utils/padding.py
def pad_to_power_of_2(length: int) -> int:
    """Find next power of 2 greater than or equal to length.

    Parameters
    ----------
    length : int
        Input length.

    Returns
    -------
    int
        Next power of 2.

    Raises
    ------
    ValueError
        If length is not positive.
    """
    if length <= 0:
        raise ValueError(f"Length must be positive, got {length}")

    if length == 1:
        return 1

    # Find next power of 2
    power = 1
    while power < length:
        power <<= 1

    return power
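The doubling loop in pad_to_power_of_2 has a closed form using Python's `int.bit_length()`: for n >= 1, the next power of 2 is `1 << (n - 1).bit_length()`. The sketch below (hypothetical function names) checks the closed form against a copy of the loop from the source.

```python
def next_power_of_2_loop(length: int) -> int:
    """Doubling loop, as in pad_to_power_of_2."""
    if length <= 0:
        raise ValueError(f"Length must be positive, got {length}")
    power = 1
    while power < length:
        power <<= 1
    return power


def next_power_of_2_bits(length: int) -> int:
    """Closed form: (length - 1).bit_length() is the exponent of the next power of 2."""
    if length <= 0:
        raise ValueError(f"Length must be positive, got {length}")
    return 1 << (length - 1).bit_length()


assert all(next_power_of_2_loop(n) == next_power_of_2_bits(n) for n in range(1, 5000))
print(next_power_of_2_bits(1), next_power_of_2_bits(5), next_power_of_2_bits(1024))  # 1 8 1024
```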

reflect_pad

reflect_pad(x: Tensor, pad_amount: int, dim: int = -1) -> Tensor

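Apply reflection padding (mirror without repeating the edge element) by appending pad_amount elements to the end of dim. For example, [1, 2, 3, 4] with pad_amount=2 becomes [1, 2, 3, 4, 3, 2].

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Input tensor. | required |
| pad_amount | int | Number of elements to append; must be smaller than the size of dim. | required |
| dim | int | Dimension to pad along. | -1 |

Returns:

| Type | Description |
|---|---|
| Tensor | Reflection-padded tensor. |

Raises:

| Type | Description |
|---|---|
| ValueError | If pad_amount is negative or not smaller than the size of dim. |
| IndexError | If dim is out of bounds. |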

Source code in spectrans/utils/padding.py
def reflect_pad(x: Tensor, pad_amount: int, dim: int = -1) -> Tensor:
    """Apply reflection padding (mirror without repeating edge).

    Parameters
    ----------
    x : Tensor
        Input tensor.
    pad_amount : int
        Number of elements to pad.
    dim : int, default=-1
        Dimension to pad along.

    Returns
    -------
    Tensor
        Reflection padded tensor.

    Raises
    ------
    ValueError
        If pad_amount is negative or too large for reflection.
    IndexError
        If dimension is out of bounds.
    """
    if pad_amount < 0:
        raise ValueError(f"Pad amount must be non-negative, got {pad_amount}")

    if pad_amount == 0:
        return x

    if dim >= x.ndim or dim < -x.ndim:
        raise IndexError(f"Dimension {dim} out of bounds for tensor with {x.ndim} dimensions")

    dim = dim % x.ndim
    seq_len = x.shape[dim]

    if pad_amount >= seq_len:
        raise ValueError(f"Reflect pad amount {pad_amount} must be < tensor size {seq_len}")

    # Reflect last pad_amount elements (excluding the edge)
    # For [1,2,3,4] with pad=2, we take elements [-3:-1] = [2,3] and flip to get [3,2]
    slices = [slice(None)] * x.ndim
    slices[dim] = slice(seq_len - pad_amount - 1, seq_len - 1)
    padding = torch.flip(x[tuple(slices)], dims=[dim])

    return torch.cat([x, padding], dim=dim)

symmetric_pad

symmetric_pad(x: Tensor, pad_amount: int, dim: int = -1) -> Tensor

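Apply symmetric padding (mirror including the edge element) by appending pad_amount elements to the end of dim. For example, [1, 2, 3, 4] with pad_amount=2 becomes [1, 2, 3, 4, 4, 3].

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Input tensor. | required |
| pad_amount | int | Number of elements to append; must not exceed the size of dim. | required |
| dim | int | Dimension to pad along. | -1 |

Returns:

| Type | Description |
|---|---|
| Tensor | Symmetrically padded tensor. |

Raises:

| Type | Description |
|---|---|
| ValueError | If pad_amount is negative or exceeds the size of dim. |
| IndexError | If dim is out of bounds. |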

Source code in spectrans/utils/padding.py
def symmetric_pad(x: Tensor, pad_amount: int, dim: int = -1) -> Tensor:
    """Apply symmetric padding (mirror including edge).

    Parameters
    ----------
    x : Tensor
        Input tensor.
    pad_amount : int
        Number of elements to pad.
    dim : int, default=-1
        Dimension to pad along.

    Returns
    -------
    Tensor
        Symmetrically padded tensor.

    Raises
    ------
    ValueError
        If pad_amount is negative or too large for symmetry.
    IndexError
        If dimension is out of bounds.
    """
    if pad_amount < 0:
        raise ValueError(f"Pad amount must be non-negative, got {pad_amount}")

    if pad_amount == 0:
        return x

    if dim >= x.ndim or dim < -x.ndim:
        raise IndexError(f"Dimension {dim} out of bounds for tensor with {x.ndim} dimensions")

    dim = dim % x.ndim
    seq_len = x.shape[dim]

    if pad_amount > seq_len:
        raise ValueError(f"Symmetric pad amount {pad_amount} must be <= tensor size {seq_len}")

    # Symmetric: mirror including the edge
    # For [1,2,3,4] with pad=2, we take last 2 elements [3,4] and flip to get [4,3]
    slices = [slice(None)] * x.ndim
    slices[dim] = slice(seq_len - pad_amount, seq_len)
    padding = torch.flip(x[tuple(slices)], dims=[dim])

    return torch.cat([x, padding], dim=dim)
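The same rule can be sketched on a plain Python list. symmetric_pad_list is a hypothetical helper, not part of spectrans, and covers only the 1-D case of appending to the end. Because the edge element is included in the mirrored window, pad_amount may equal the sequence length, in which case the whole sequence is mirrored.

```python
from typing import TypeVar

T = TypeVar("T")


def symmetric_pad_list(xs: list[T], pad_amount: int) -> list[T]:
    """Append a mirror image of the last pad_amount elements (edge included)."""
    if pad_amount < 0:
        raise ValueError(f"Pad amount must be non-negative, got {pad_amount}")
    if pad_amount == 0:
        return list(xs)
    n = len(xs)
    if pad_amount > n:
        raise ValueError(f"Symmetric pad amount {pad_amount} must be <= tensor size {n}")
    # Same window as symmetric_pad: indices [n - pad_amount, n), then reversed.
    return xs + xs[n - pad_amount :][::-1]


print(symmetric_pad_list([1, 2, 3, 4], 2))  # [1, 2, 3, 4, 4, 3]
print(symmetric_pad_list([1, 2, 3, 4], 4))  # [1, 2, 3, 4, 4, 3, 2, 1]
```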

unpad_sequence

unpad_sequence(padded_tensor: Tensor, lengths: list[int], dim: int = -1) -> list[Tensor]

Split a batched, padded tensor along its first (batch) dimension and trim each sequence back to its original length along dim.

Parameters:

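Name | Type | Description | Default
padded_tensor | Tensor | Batched, padded tensor whose first dimension is the batch dimension. | required
lengths | list[int] | Original length of each sequence, one entry per batch element. | required
dim | int | Dimension that was padded, indexed relative to a single sequence (that is, excluding the batch dimension). | -1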

Returns:

Type | Description
list[Tensor] | One tensor per batch element, trimmed along dim to its original length.

Raises:

Type | Description
ValueError | If padded_tensor is a scalar, or if the number of lengths does not match the batch size.
IndexError | If dim is out of bounds.

Source code in spectrans/utils/padding.py
def unpad_sequence(padded_tensor: Tensor, lengths: list[int], dim: int = -1) -> list[Tensor]:
    """Unpad a batched tensor back to individual sequences.

    Parameters
    ----------
    padded_tensor : Tensor
        Batched, padded tensor.
    lengths : list[int]
        Original lengths of each sequence.
    dim : int, default=-1
        Dimension that was padded.

    Returns
    -------
    list[Tensor]
        List of unpadded sequences.

    Raises
    ------
    ValueError
        If lengths don't match batch size.
    IndexError
        If dimension is out of bounds.
    """
    if padded_tensor.ndim == 0:
        raise ValueError("Cannot unpad scalar tensor")

    batch_size = padded_tensor.shape[0]
    if len(lengths) != batch_size:
        raise ValueError(f"Number of lengths ({len(lengths)}) must match batch size ({batch_size})")

    # Adjust for batch dimension
    batch_dim = dim
    if batch_dim >= 0:
        batch_dim += 1  # Account for batch dimension

    if batch_dim >= padded_tensor.ndim or batch_dim < -padded_tensor.ndim:
        raise IndexError(
            f"Dimension {batch_dim} out of bounds for tensor with {padded_tensor.ndim} dimensions"
        )

    sequences = []
    for i, length in enumerate(lengths):
        seq = padded_tensor[i]
        if length < seq.shape[dim]:  # Use original dim for individual tensor
            seq = unpad_to_length(seq, length, dim)
        sequences.append(seq)

    return sequences
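The pad/unpad contract can be sketched without tensors: a batch is right-padded to its longest member, the original lengths are recorded, and unpadding trims each row back. pad_batch and unpad_batch are hypothetical helpers operating on nested Python lists, standing in for a (batch, seq_len) tensor with dim=-1; they are not part of spectrans, and right-padding with zeros is an assumption of the sketch.

```python
def pad_batch(seqs: list[list[float]], value: float = 0.0) -> tuple[list[list[float]], list[int]]:
    """Right-pad every sequence to the longest length; return the batch and original lengths."""
    lengths = [len(s) for s in seqs]
    max_len = max(lengths, default=0)
    batch = [s + [value] * (max_len - len(s)) for s in seqs]
    return batch, lengths


def unpad_batch(batch: list[list[float]], lengths: list[int]) -> list[list[float]]:
    """Trim each row back to its original length, mirroring unpad_sequence with dim=-1."""
    if len(lengths) != len(batch):
        raise ValueError(
            f"Number of lengths ({len(lengths)}) must match batch size ({len(batch)})"
        )
    # As in unpad_sequence, a row is only sliced when its length is smaller than the padded size.
    return [row[:length] if length < len(row) else row for row, length in zip(batch, lengths)]


batch, lengths = pad_batch([[1.0, 2.0, 3.0], [4.0], [5.0, 6.0]])
print(batch)                        # [[1.0, 2.0, 3.0], [4.0, 0.0, 0.0], [5.0, 6.0, 0.0]]
print(lengths)                      # [3, 1, 2]
print(unpad_batch(batch, lengths))  # [[1.0, 2.0, 3.0], [4.0], [5.0, 6.0]]
```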

unpad_to_length

unpad_to_length(x: Tensor, target_length: int, dim: int = -1) -> Tensor

Remove trailing padding by keeping only the first target_length elements along dim.

Parameters:

Name | Type | Description | Default
x | Tensor | Padded tensor. | required
target_length | int | Original length before padding. | required
dim | int | Dimension to unpad along. | -1

Returns:

Type | Description
Tensor | Tensor with padding removed.

Raises:

Type | Description
ValueError | If target_length is larger than the current length along dim.
IndexError | If dim is out of bounds.

Source code in spectrans/utils/padding.py
def unpad_to_length(x: Tensor, target_length: int, dim: int = -1) -> Tensor:
    """Remove padding to restore original length.

    Parameters
    ----------
    x : Tensor
        Padded tensor.
    target_length : int
        Original length before padding.
    dim : int, default=-1
        Dimension to unpad along.

    Returns
    -------
    Tensor
        Tensor with padding removed.

    Raises
    ------
    ValueError
        If target_length is larger than current length.
    IndexError
        If dimension is out of bounds.
    """
    if dim >= x.ndim or dim < -x.ndim:
        raise IndexError(f"Dimension {dim} out of bounds for tensor with {x.ndim} dimensions")

    # Normalize negative dimension
    dim = dim % x.ndim
    current_length = x.shape[dim]

    if target_length > current_length:
        raise ValueError(
            f"Target length {target_length} must be <= current length {current_length}"
        )

    if target_length == current_length:
        return x

    # Create slice objects to extract the original data
    slices = [slice(None)] * x.ndim
    slices[dim] = slice(0, target_length)

    return x[tuple(slices)]
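A typical use is undoing length padding added before an FFT, for example by pad_to_power_of_2. The sketch below reproduces only the length arithmetic on a plain list: next_power_of_2 and pad_then_trim are hypothetical helpers, zero right-padding is assumed, and the spectral operation in between is omitted.

```python
def next_power_of_2(n: int) -> int:
    """Smallest power of two that is >= n, for n >= 1."""
    if n < 1:
        raise ValueError(f"Length must be positive, got {n}")
    return 1 << (n - 1).bit_length()


def pad_then_trim(xs: list[float]) -> tuple[list[float], list[float]]:
    """Zero-pad xs on the right to a power-of-2 length, then trim it back."""
    target_length = len(xs)
    padded = xs + [0.0] * (next_power_of_2(target_length) - target_length)
    # A spectral operation on `padded` would run here (omitted in this sketch).
    # Keep the first target_length entries, as unpad_to_length does along dim.
    restored = padded[:target_length]
    return padded, restored


padded, restored = pad_then_trim([1.0, 2.0, 3.0, 4.0, 5.0])
print(len(padded))  # 8
print(restored)     # [1.0, 2.0, 3.0, 4.0, 5.0]
```

Trimming restores the original length; whether the retained values match an unpadded computation depends on the operation applied to the padded signal.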

zero_pad

zero_pad(x: Tensor, pad_amount: int, dim: int = -1, value: float = 0.0) -> Tensor

Append constant padding (zeros by default) to the end of dim.

Parameters:

Name | Type | Description | Default
x | Tensor | Input tensor. | required
pad_amount | int | Number of elements to pad. | required
dim | int | Dimension to pad along. | -1
value | float | Constant value to pad with. | 0.0

Returns:

Type | Description
Tensor | Padded tensor whose appended elements all equal value (0.0 by default).

Raises:

Type | Description
ValueError | If pad_amount is negative.
IndexError | If dim is out of bounds.

Source code in spectrans/utils/padding.py
def zero_pad(x: Tensor, pad_amount: int, dim: int = -1, value: float = 0.0) -> Tensor:
    """Apply zero (constant) padding.

    Parameters
    ----------
    x : Tensor
        Input tensor.
    pad_amount : int
        Number of elements to pad.
    dim : int, default=-1
        Dimension to pad along.
    value : float, default=0.0
        Constant value to pad with.

    Returns
    -------
    Tensor
        Zero-padded tensor.

    Raises
    ------
    ValueError
        If pad_amount is negative.
    IndexError
        If dimension is out of bounds.
    """
    if pad_amount < 0:
        raise ValueError(f"Pad amount must be non-negative, got {pad_amount}")

    if pad_amount == 0:
        return x

    if dim >= x.ndim or dim < -x.ndim:
        raise IndexError(f"Dimension {dim} out of bounds for tensor with {x.ndim} dimensions")

    dim = dim % x.ndim

    # Create padding tensor with same shape except in padding dimension
    pad_shape = list(x.shape)
    pad_shape[dim] = pad_amount

    padding = torch.full(pad_shape, value, dtype=x.dtype, device=x.device)

    return torch.cat([x, padding], dim=dim)
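zero_pad builds its padding block from the input shape with only the padded dimension replaced, after normalizing a negative dim with dim % x.ndim. That shape bookkeeping can be checked with a small sketch over plain tuples; padded_shapes is a hypothetical helper, not part of spectrans, and it allocates no data.

```python
def padded_shapes(
    shape: tuple[int, ...], pad_amount: int, dim: int = -1
) -> tuple[tuple[int, ...], tuple[int, ...]]:
    """Return (padding block shape, output shape) for appending pad_amount entries along dim."""
    if pad_amount < 0:
        raise ValueError(f"Pad amount must be non-negative, got {pad_amount}")
    ndim = len(shape)
    if dim >= ndim or dim < -ndim:
        raise IndexError(f"Dimension {dim} out of bounds for tensor with {ndim} dimensions")
    dim = dim % ndim  # Normalize negative dims, as zero_pad does.
    pad_shape = list(shape)
    pad_shape[dim] = pad_amount  # Same shape as the input except along dim.
    out_shape = list(shape)
    out_shape[dim] += pad_amount  # torch.cat along dim adds the two sizes.
    return tuple(pad_shape), tuple(out_shape)


print(padded_shapes((2, 3, 5), 3))         # ((2, 3, 3), (2, 3, 8))
print(padded_shapes((2, 3, 5), 1, dim=0))  # ((1, 3, 5), (3, 3, 5))
```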