Overview

This guide shows how to train linear RNN models with lrnnx. The library provides training APIs for both linear time-invariant (LTI) and linear time-varying (LTV) models.
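
To make the two model families concrete, the sketch below runs a scalar linear recurrence h_t = a_t * h_{t-1} + b * x_t in plain Python. It is a textbook illustration under simplifying assumptions, not lrnnx's implementation: the name `linear_recurrence` and its arguments are hypothetical. A single constant `a` gives the time-invariant case, while one coefficient per step gives the time-varying case.

```python
from typing import List, Sequence, Union


def linear_recurrence(
    x: Sequence[float],
    a: Union[float, Sequence[float]],
    b: float = 1.0,
) -> List[float]:
    """Hypothetical helper: run h_t = a_t * h_{t-1} + b * x_t starting from h = 0."""
    if isinstance(a, (int, float)):
        # Time-invariant: the same transition coefficient at every step
        coeffs = [float(a)] * len(x)
    else:
        # Time-varying: one transition coefficient per step
        coeffs = [float(v) for v in a]
        if len(coeffs) != len(x):
            raise ValueError("need one transition coefficient per time step")
    h = 0.0
    states = []
    for a_t, x_t in zip(coeffs, x):
        h = a_t * h + b * x_t
        states.append(h)
    return states


impulse = [1.0, 0.0, 0.0]
lti_states = linear_recurrence(impulse, a=0.5)              # constant decay
ltv_states = linear_recurrence(impulse, a=[0.5, 0.0, 0.5])  # step 2 resets the state
```

In the time-invariant run the impulse decays by the same factor at every step. In the time-varying run the zero coefficient at the second step erases the state, which is the kind of step-dependent behavior that selective models rely on.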

Quick Start

1. Import models

Import the model architectures from lrnnx:
from lrnnx.models.lti import LRU, S4, S4D, S5
2. Instantiate model

Create a model instance in training mode:
import torch
from lrnnx.models.lti import LRU

# Model parameters
d_model = 64      # Model dimension
d_state = 64      # State dimension

# Create model on CUDA
model = LRU(d_model=d_model, d_state=d_state).cuda()
model.train()  # Set to training mode
3. Create input tensors

Prepare your training data:
batch_size = 32
seq_len = 128
d_model = 64

# Create input tensor (B, L, H)
x = torch.randn(
    batch_size, seq_len, d_model,
    dtype=torch.float32,
    device="cuda"
)
All lrnnx models expect input of shape (batch_size, seq_len, d_model) where:
  • batch_size: Number of sequences in the batch
  • seq_len: Length of each sequence
  • d_model: Feature dimension
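
A shape mismatch is easiest to catch before the forward pass. The sketch below is a minimal check of the (batch_size, seq_len, d_model) layout described above; `check_input_shape` is a hypothetical helper written for this guide, not part of lrnnx.

```python
from typing import Sequence


def check_input_shape(shape: Sequence[int], d_model: int) -> None:
    """Hypothetical helper: raise ValueError unless shape is (batch_size, seq_len, d_model)."""
    if len(shape) != 3:
        raise ValueError(f"expected 3 dimensions (B, L, H), got {len(shape)}")
    batch_size, seq_len, features = shape
    if batch_size < 1 or seq_len < 1:
        raise ValueError("batch_size and seq_len must be at least 1")
    if features != d_model:
        raise ValueError(
            f"last dimension is {features}, but the model expects d_model={d_model}"
        )


check_input_shape((32, 128, 64), d_model=64)  # passes silently
```

With a tensor, call `check_input_shape(x.shape, d_model)`; `torch.Size` behaves like a tuple of ints, so no conversion is needed.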
4. Forward and backward pass

Run the training loop:
import torch.nn as nn
import torch.optim as optim

# Define loss and optimizer
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Training loop
num_epochs = 10  # Example value; choose to suit your task
for epoch in range(num_epochs):
    # Forward pass
    output = model(x)  # Shape: (batch_size, seq_len, d_model)

    # Compute loss (example: reconstruction)
    loss = criterion(output, x)

    # Backward pass
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    print(f"Epoch {epoch}, Loss: {loss.item():.4f}")

Complete Training Example

Here’s a complete example showing forward and backward passes:
import torch
import torch.nn as nn
import torch.optim as optim
from lrnnx.models.lti import LRU

# Model configuration
d_model = 64
d_state = 64
batch_size = 32
seq_len = 128

# Initialize model
model = LRU(d_model=d_model, d_state=d_state).cuda()
model.train()

# Create sample data
x = torch.randn(batch_size, seq_len, d_model, device="cuda")
target = torch.randn(batch_size, seq_len, d_model, device="cuda")

# Loss and optimizer
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Single training step
optimizer.zero_grad()
output = model(x)
loss = criterion(output, target)
loss.backward()
optimizer.step()

print(f"Loss: {loss.item():.4f}")

Event-Based Training

Some models (S5, S6, Mamba) support event-based processing with custom integration timesteps:
import torch
from lrnnx.models.lti import S5

model = S5(d_model=64, d_state=64).cuda()
model.train()

x = torch.randn(32, 128, 64, device="cuda")

# Provide integration timesteps (B, L); torch.rand can return 0, so add a small offset to keep them positive
integration_timesteps = torch.rand(32, 128, device="cuda") + 1e-3

# Forward pass with custom timesteps
output = model(x, integration_timesteps=integration_timesteps)
When using integration_timesteps, pass strictly positive values of shape (batch_size, seq_len); each value is the time interval between consecutive events.
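
Event data usually arrives as absolute timestamps rather than intervals. The sketch below shows one way to convert them; `timestamps_to_intervals` and its `first_dt` argument are hypothetical, and the interval assigned to the first event (which has no predecessor) is an assumption you should adapt to your data.

```python
from typing import List, Sequence


def timestamps_to_intervals(
    timestamps: Sequence[float], first_dt: float = 1.0
) -> List[float]:
    """Hypothetical helper: turn increasing event times into positive per-event intervals."""
    if first_dt <= 0:
        raise ValueError("first_dt must be positive")
    # The first event has no predecessor, so it gets the assumed default interval
    intervals = [first_dt] if timestamps else []
    for prev, curr in zip(timestamps, timestamps[1:]):
        dt = curr - prev
        if dt <= 0:
            raise ValueError("timestamps must be strictly increasing")
        intervals.append(dt)
    return intervals


event_times = [0.0, 0.5, 2.0, 2.25]
intervals = timestamps_to_intervals(event_times)
```

The result has one interval per event, so it matches the sequence length. To build the (B, L) input, convert one such list per sequence and stack them, for example with `torch.tensor(batch_of_intervals, device="cuda")`.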

Benchmarking Training Performance

The library includes built-in benchmarking utilities to measure training throughput:
from benchmarks.benchmark_training import benchmark_sequence_length
from lrnnx.models.lti import LRU

def model_fn():
    return LRU(d_model=64, d_state=64).cuda()

# Benchmark across different sequence lengths
results = benchmark_sequence_length(
    model_fn,
    seq_lengths=[128, 256, 512, 1024, 2048],
    batch_size=32,
    repeats=5
)

for seq_len, times in results.items():
    avg_time = sum(times) / len(times)
    print(f"Seq len {seq_len}: {avg_time:.2f} ms")
See the full benchmarking suite in benchmarks/benchmark_training.py for more examples including:
  • Varying model dimensions
  • Varying batch sizes
  • Multi-run statistics
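
A mean alone hides run-to-run variance. The sketch below summarizes each sequence length with a mean and a sample standard deviation, assuming `results` maps sequence length to a list of per-run times as in the loop above. `summarize_timings` is a hypothetical helper, and the numbers in `example_results` are placeholders, not measurements.

```python
import statistics
from typing import Dict, List, Tuple


def summarize_timings(results: Dict[int, List[float]]) -> Dict[int, Tuple[float, float]]:
    """Hypothetical helper: map each sequence length to (mean, sample std dev) of its timings."""
    summary = {}
    for seq_len, times in sorted(results.items()):
        mean = statistics.mean(times)
        # stdev needs at least two samples; report 0.0 for a single run
        std = statistics.stdev(times) if len(times) > 1 else 0.0
        summary[seq_len] = (mean, std)
    return summary


# Placeholder numbers for illustration only, not measured results
example_results = {128: [2.0, 4.0], 256: [5.0, 5.0, 5.0]}
for seq_len, (mean, std) in summarize_timings(example_results).items():
    print(f"Seq len {seq_len}: {mean:.2f} ± {std:.2f} ms")
```

Discarding the first few repeats before summarizing is a common way to keep one-time warm-up costs out of the statistics.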

Mixed Precision Training

For faster training with lower memory usage, use automatic mixed precision:
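import torch
import torch.nn as nn
# Note: recent PyTorch releases deprecate torch.cuda.amp in favor of
# torch.amp.autocast("cuda") and torch.amp.GradScaler("cuda")
from torch.cuda.amp import autocast, GradScaler
from lrnnx.models.ltv import Mamba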

model = Mamba(d_model=64, d_state=16).cuda()
optimizer = torch.optim.Adam(model.parameters())
scaler = GradScaler()

x = torch.randn(32, 128, 64, device="cuda")
target = torch.randn(32, 128, 64, device="cuda")

num_epochs = 10  # Example value; choose to suit your task
for epoch in range(num_epochs):
    optimizer.zero_grad()

    # Forward pass with autocast
    with autocast():
        output = model(x)
        loss = nn.functional.mse_loss(output, target)

    # Backward pass with gradient scaling
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

Available Models

All models follow the same training interface:

LTI Models (Time-Invariant)

  • S4 - Structured State Space model
  • S4D - Diagonal variant of S4
  • S5 - Simplified State Space model
  • LRU - Linear Recurrent Unit

LTV Models (Time-Varying)

  • Mamba - Selective State Space model
  • S6/S7 - Extensions with selective mechanisms
  • RGLRU - Real-Gated Linear Recurrent Unit

Next Steps

Inference Guide

Learn about fast autoregressive generation with CUDA graphs

Custom Kernels

Understand the CUDA kernels powering lrnnx