Quick Start Guide¶
This tutorial will get you up and running with Spectrans in minutes. We'll cover the essential patterns for using spectral transformers in your projects.
Prerequisites¶
Make sure you have Spectrans installed. If not, see the Installation Guide.
Your First Spectral Model¶
Let's start with the most straightforward approach - creating an FNet model for text classification:
import torch
from spectrans.models import FNet
# Create an FNet model for binary classification
model = FNet(
vocab_size=30000, # Size of your vocabulary
hidden_dim=768, # Hidden dimension (like BERT-base)
num_layers=12, # Number of transformer layers
max_sequence_length=512, # Maximum input length
num_classes=2 # Binary classification
)
# Create some sample input
batch_size = 4
seq_len = 128
input_ids = torch.randint(0, 30000, (batch_size, seq_len))
# Forward pass
logits = model(input_ids=input_ids)
print(f"Output shape: {logits.shape}") # (4, 2)
# Apply softmax for probabilities
probs = torch.softmax(logits, dim=-1)
print(f"Probabilities: {probs}")
Using Pre-Embedded Inputs¶
If you already have embeddings (from a pre-trained model, for example):
from spectrans.models import FNetEncoder
# For encoder-only usage (no classification head)
encoder = FNetEncoder(
hidden_dim=768,
num_layers=12,
max_sequence_length=512,
dropout=0.1
)
# Use with embeddings directly
embeddings = torch.randn(4, 128, 768) # (batch, seq_len, hidden_dim)
encoded = encoder(inputs_embeds=embeddings)
print(f"Encoded shape: {encoded.shape}") # (4, 128, 768)
Configuration-Based Models¶
Spectrans provides YAML-based configuration for easy experimentation:
from spectrans.config import ConfigBuilder
# Load a pre-defined configuration
builder = ConfigBuilder()
model = builder.build_model("examples/configs/fnet.yaml")
# The config file defines all model parameters
# You can also create custom configs
Here's what a configuration file looks like:
# examples/configs/my_fnet.yaml
model:
model_type: "fnet"
hidden_dim: 512
num_layers: 8
sequence_length: 256
dropout: 0.1
vocab_size: 20000
num_classes: 5
ffn_hidden_dim: 2048
use_real_fft: true
layers:
mixing:
type: "fourier"
hidden_dim: 512
dropout: 0.1
fft_norm: "ortho"
Exploring Different Architectures¶
GFNet (Global Filter Networks)¶
from spectrans.models import GFNet
# GFNet uses learnable filters in frequency domain
gfnet = GFNet(
vocab_size=30000,
hidden_dim=512,
num_layers=8,
max_sequence_length=256,
num_classes=10
)
input_ids = torch.randint(0, 30000, (2, 256))
logits = gfnet(input_ids=input_ids)
print(f"GFNet output: {logits.shape}") # (2, 10)
AFNO (Adaptive Fourier Neural Operator)¶
from spectrans.models import AFNOModel
# AFNO is great for tasks requiring global context
afno = AFNOModel(
vocab_size=30000,
hidden_dim=512,
num_layers=8,
max_sequence_length=256,
modes_seq=32, # Number of Fourier modes to keep
num_classes=10
)
# AFNO excels with longer sequences
input_ids = torch.randint(0, 30000, (4, 256))
logits = afno(input_ids=input_ids)
print(f"AFNO output: {logits.shape}") # (4, 10)
Wavelet Transformers¶
from spectrans.models import WaveletTransformer
# Wavelet transforms provide multi-resolution analysis
wavelet_model = WaveletTransformer(
vocab_size=30000,
hidden_dim=512,
num_layers=8,
wavelet="db4", # Daubechies-4 wavelet
levels=3, # 3 decomposition levels
max_sequence_length=256,
num_classes=10
)
input_ids = torch.randint(0, 30000, (2, 256))
logits = wavelet_model(input_ids=input_ids)
print(f"Wavelet output: {logits.shape}") # (2, 10)
Working with Individual Components¶
For more control, you can compose models from individual components:
from spectrans.layers.mixing import FourierMixing, WaveletMixing
from spectrans.blocks import SpectralTransformerBlock
from spectrans.models.base import BaseModel
# Create custom mixing layers
fourier_mixing = FourierMixing(hidden_dim=768, dropout=0.1)
wavelet_mixing = WaveletMixing(hidden_dim=768, wavelet="db6", levels=4)
# Create transformer blocks
fourier_block = SpectralTransformerBlock(
mixing_layer=fourier_mixing,
hidden_dim=768,
ffn_hidden_dim=3072
)
wavelet_block = SpectralTransformerBlock(
mixing_layer=wavelet_mixing,
hidden_dim=768,
ffn_hidden_dim=3072
)
# Combine in a custom model (you would need to implement the full model class)
Training Your Model¶
Here's a simple training loop example:
import torch
import torch.nn as nn
from torch.optim import AdamW
from spectrans.models import FNet
# Setup model and optimizer
model = FNet(
vocab_size=30000,
hidden_dim=256,
num_layers=6,
max_sequence_length=128,
num_classes=2
)
optimizer = AdamW(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()
# Training loop
for epoch in range(3):
# Sample batch
input_ids = torch.randint(0, 30000, (8, 128))
labels = torch.randint(0, 2, (8,))
# Forward pass
logits = model(input_ids=input_ids)
loss = criterion(logits, labels)
# Backward pass
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(f"Epoch {epoch}, Loss: {loss.item():.4f}")
GPU Usage¶
All models support GPU acceleration automatically:
device = "cuda" if torch.cuda.is_available() else "cpu"
model = FNet(vocab_size=30000, hidden_dim=768, num_layers=12,
max_sequence_length=512, num_classes=2)
model = model.to(device)
# Data will need to be moved to GPU as well
input_ids = torch.randint(0, 30000, (4, 128)).to(device)
logits = model(input_ids=input_ids)
Memory Optimization¶
For large models or long sequences:
# Enable gradient checkpointing to save memory
model = FNet(vocab_size=30000, hidden_dim=1024, num_layers=24,
max_sequence_length=1024, num_classes=2,
gradient_checkpointing=True)
# Use mixed precision training
from torch.cuda.amp import autocast, GradScaler
scaler = GradScaler()
with autocast():
logits = model(input_ids=input_ids)
loss = F.cross_entropy(logits, labels)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
What's Next?¶
Now that you've got the basics:
- Explore Examples: Check out the examples directory for complete applications
- Read the API Documentation: Detailed reference at API docs
- Experiment with Configurations: Modify the YAML configs in
examples/configs/to try different architectures - Contribute: See the Contributing Guide for development setup
- Build Custom Components: Learn how to create your own spectral layers
Key Concepts to Remember¶
- Direct Instantiation: Use model classes like
FNet(),GFNet()directly - Configuration: Use YAML files with
ConfigBuilderfor experiments - Components: Mix and match layers, blocks, and transforms
- Efficiency: Spectral methods offer \(O(n \log n)\) complexity vs \(O(n^2)\) for attention
- Flexibility: Works with both token IDs and pre-computed embeddings
Happy building with Spectrans! 🚀