Public SAEs trained on base models are not the right tool if you are studying a fine-tuned or domain-adapted model. Anthropic's release covers Claude Sonnet and GPT-2 scale models. EleutherAI's open suite covers Pythia and LLaMA base weights. If your team is working on a fine-tuned, domain-adapted, or proprietary model like a Llama 3.1 70B trained on internal legal documents, or a code-specialized Qwen2.5 variant, those public SAEs will surface features that map to the base model's pretraining distribution, not the behaviors introduced by your fine-tuning. The features you care about will either be absent or masked by irrelevant ones.
The second problem is data sovereignty. SAE training requires capturing activations layer-by-layer from the base model across millions of tokens. These activation tensors encode the model's internal representations of your actual data. For most enterprise teams doing safety and governance work, shipping raw activation data to a third-party cloud for SAE training is ruled out by the same data processing agreements (DPAs) that govern model weights and training data. Self-hosted training on controlled infrastructure is the only viable path.
Safety teams doing AI red teaming increasingly need interpretability tools alongside attack and defense tooling. An SAE trained on your specific model gives you a feature-level map of what the model knows, which is more actionable than prompt-level red teaming alone.
SAE Training Workload Profile
Understanding the resource profile before provisioning saves money and prevents failed runs. SAE training looks like an LLM fine-tuning job from the outside but behaves very differently.
Activation Buffer Size
The math is straightforward. For a 70B model with d_model = 8192, each activation vector at a single layer is 8192 * 4 bytes = 32KB in float32, or 16KB in float16. If you want to train on 10 million tokens (a reasonable minimum for a large model), the full activation dataset is 10M * 32KB = 320GB in float32 or 160GB in float16.
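The same arithmetic as a tiny helper you can re-run for other model sizes (the 320GB figure above rounds 32.768KB per token down to 32KB):

def activation_buffer_gb(d_model, n_tokens, bytes_per_value=4):
    # Total buffer size in GB: one d_model-wide vector per token
    return d_model * bytes_per_value * n_tokens / 1e9

print(activation_buffer_gb(8192, 10_000_000))     # ~328 GB in float32
print(activation_buffer_gb(8192, 10_000_000, 2))  # ~164 GB in float16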
You have two options: cache the full activation dataset to NVMe and load randomly during SAE training (maximum training throughput, expensive on storage), or stream activations from the base model during SAE training (slower, no storage cost for the buffer). For most teams starting out, streaming is simpler.
Compute Profile
The SAE forward pass is a large matmul (d_model x d_sae, where d_sae = expansion_factor * d_model) followed by TopK selection. At expansion_factor = 32 and d_model = 8192, the SAE hidden dimension is 262,144. Even so, the per-step compute is small relative to the GPU's peak FLOP rate: the run spends most of its time moving activation batches and SAE weights through memory, so GPU compute units sit underutilized. VRAM bandwidth is the real bottleneck.
This matters for hardware selection: a GPU with more VRAM and higher memory bandwidth (H200, B200) outperforms one with a higher peak FLOP rate but less memory bandwidth.
Workload Phase Summary
| Workload Phase | GPU Bound By | Recommended GPU |
|---|---|---|
| Activation capture (7B-34B model) | VRAM capacity, compute | H100 SXM5 |
| Activation capture (70B+ model) | VRAM capacity (need model in memory) | H200 |
| SAE training (streaming) | VRAM bandwidth | H100 SXM5 or H200 |
| SAE training (cached, large buffer) | VRAM capacity + bandwidth | H200 or B200 |
| Multi-GPU SAE (FSDP) | Inter-GPU bandwidth (NVLink) | 2-4x H100 SXM5 NVLink |
Choosing an SAE Architecture
Four architectures are actively used in production interpretability work as of 2026:
| Architecture | Creator | Sparsity Mechanism | Interpretability Quality | Training Cost | Best For |
|---|---|---|---|---|---|
| TopK | OpenAI | Hard topk selection | High | Low | General use, starting point |
| JumpReLU | Google DeepMind | Learned per-feature threshold | Medium-high | Medium | Code, math token fidelity |
| Gated | Google DeepMind | Gating network | High (low polysemanticity) | High | Research on polysemanticity |
| Matryoshka | 2025 research | Hierarchical resolution | High (multi-scale) | Very high | Coarse+fine feature study |
TopK
TopK SAEs were introduced by Gao et al. (OpenAI, arXiv:2406.04093, June 2024) and have since been adopted by several interpretability labs. The sparsity constraint is explicit: exactly k features activate per token. Training is stable, dead feature rates are low at matched k, and automated interpretation quality is consistently higher than alternatives in head-to-head tests. Start here unless you have a specific reason not to.
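A minimal sketch of the TopK mechanism in plain PyTorch makes the sparsity constraint concrete. This is illustrative only, not the sae-lens implementation; production code adds decoder norm constraints, auxiliary losses, and dead-feature handling.

import torch
import torch.nn as nn

class TopKSAE(nn.Module):
    def __init__(self, d_model: int, expansion_factor: int = 32, k: int = 64):
        super().__init__()
        d_sae = expansion_factor * d_model
        self.k = k
        self.W_enc = nn.Parameter(torch.randn(d_model, d_sae) * 0.01)
        self.W_dec = nn.Parameter(torch.randn(d_sae, d_model) * 0.01)
        self.b_enc = nn.Parameter(torch.zeros(d_sae))
        self.b_dec = nn.Parameter(torch.zeros(d_model))

    def forward(self, x):
        # Encode, keep only the k largest pre-activations per token, decode
        pre = (x - self.b_dec) @ self.W_enc + self.b_enc
        topk = torch.topk(torch.relu(pre), self.k, dim=-1)
        acts = torch.zeros_like(pre).scatter_(-1, topk.indices, topk.values)
        recon = acts @ self.W_dec + self.b_dec
        return recon, acts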
JumpReLU
DeepMind's approach uses a learned threshold per feature rather than a global k. This produces better reconstruction fidelity on structured outputs like code and math, where token distributions are very different from natural language. The trade-off: interpretation scores are slightly lower because features are more polysemantic. Use it if you are studying a code model and reconstruction loss matters more than feature clarity.
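The core difference from TopK fits in a few lines: each feature carries its own learned threshold rather than competing in a global top-k (training the thresholds needs a straight-through estimator, omitted in this sketch):

import torch

def jumprelu(pre_acts: torch.Tensor, threshold: torch.Tensor) -> torch.Tensor:
    # threshold has shape [d_sae]; a feature fires only when its pre-activation
    # clears its own threshold
    return pre_acts * (pre_acts > threshold)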
Gated
Google DeepMind's Gated SAE (Rajamanoharan et al., arXiv:2404.16014, April 2024) adds a gating network that controls which features can activate, introducing an inductive bias toward monosemantic features. Training is harder (the gating network introduces instability) and the dead feature rate is harder to control. The payoff is that surviving features tend to be cleaner. Useful for research focused on polysemanticity rather than production interpretability pipelines.
Matryoshka
The Matryoshka approach (2025) trains multiple resolution levels simultaneously: coarse features (low k) and fine features (high k) from the same SAE weights. This is expensive but gives you a hierarchical feature map that is useful if you want to study both broad semantic categories and fine-grained token-level patterns. Training cost is roughly 3-4x a standard TopK run.
Hardware Sizing for SAE Training
The base model must fit in VRAM during activation capture. The SAE itself is small relative to the base model. The bottleneck is the base model size, not the SAE.
| Base Model | Layer Range | d_model | Activations per 1M Tokens | Recommended GPU | VRAM Headroom |
|---|---|---|---|---|---|
| 7B / 8B | 16-32 | 4096 | 8GB (float16) | H100 SXM5 | Large |
| Yi-34B | 32-48 | 7168 | 14GB (float16) | H100 SXM5 80GB | Comfortable |
| 70B | 48-80 | 8192 | 16GB (float16) per layer | H200 141GB | Comfortable |
| 70B full buffer | any | 8192 | 160GB (10M tokens, float16) | H200 or 2x H100 SXM5 | Use streaming |
| 70B cached buffer | any | 8192 | 320GB (10M tokens, float32) | 2x B200 or 4x H100 | Multi-GPU |
For most 70B interpretability work, a single H200 on Spheron is the right starting point. The 141GB of HBM3e holds the 70B model in bfloat16 (~140GB) with almost no headroom, which is workable because the capture hook streams activations straight to host memory and NVMe; the SAE training loop then runs on a separate H100 SXM5 instance or on the same H200 after activation capture completes.
GPU Pricing (On-Demand and Spot)
| GPU | On-Demand (per GPU/hr) | Spot (per GPU/hr) | Typical SAE Training Run | Cost at On-Demand |
|---|---|---|---|---|
| H100 SXM5 | $4.00 | $1.69 | 12-24 hrs (7B-34B model) | $48-$96 |
| H200 SXM5 | $4.72 | $1.95 | 18-36 hrs (70B model, streaming) | $85-$170 |
| B200 SXM6 | $7.00 | $1.71 | 8-16 hrs (70B model, cached buffer) | $56-$112 |
Pricing fluctuates based on GPU availability. The prices above reflect rates as of 14 May 2026 and may have changed. Check current GPU pricing → for live rates.
For activation capture (the base model must stay in VRAM, no checkpointing), use on-demand instances. For the SAE training loop (checkpoint-friendly), spot instances at roughly 59% savings on H200 ($1.95 vs $4.72 on-demand) are the economical choice.
Step-by-Step Training Stack
1. Activation Capture Hook
Load the base model and register a forward hook on the residual stream output at your target layer. Write to safetensors shards as you go.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from safetensors.torch import save_file
from pathlib import Path
model_id = "meta-llama/Llama-3.1-70B"
layer_idx = 32 # Mid-to-late layers have the richest semantic features
shard_size = 50_000 # tokens per shard
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
activation_buffer = []
shard_index = 0
remainder = None  # carries overflow activations from the previous batch

def hook_fn(module, input, output):
    # output[0] is the residual stream tensor: [batch, seq, d_model]
    activation_buffer.append(output[0].detach().float().cpu())

hook = model.model.layers[layer_idx].register_forward_hook(hook_fn)
output_dir = Path("activations/layer_32")
output_dir.mkdir(parents=True, exist_ok=True)

for batch in dataloader:  # dataloader yields tokenized batches; a sketch follows below
    with torch.no_grad():
        model(**batch)
    gathered = torch.cat(activation_buffer, dim=1).reshape(-1, model.config.hidden_size)
    activation_buffer.clear()
    if remainder is not None:
        gathered = torch.cat([remainder, gathered], dim=0)
    while len(gathered) >= shard_size:
        save_file(
            {"activations": gathered[:shard_size].half()},
            output_dir / f"shard_{shard_index:05d}.safetensors"
        )
        shard_index += 1
        gathered = gathered[shard_size:]
    remainder = gathered if len(gathered) > 0 else None

hook.remove()

# Flush final partial shard so no activations are silently dropped
if remainder is not None and len(remainder) > 0:
    save_file(
        {"activations": remainder.half()},
        output_dir / f"shard_{shard_index:05d}.safetensors"
    )

For a 70B model at batch size 8 on a single H200, expect roughly 1.5 hours to capture 10M tokens at layer 32. Write shards to the instance's local NVMe, not to object storage, to avoid write latency stalls.
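The loop above assumes a dataloader that yields tokenized batches on the model's device. A minimal sketch using Hugging Face datasets in streaming mode; the corpus name and sequence length are placeholder choices, and a production pipeline would also drop padded positions using the attention mask:

from datasets import load_dataset

raw = load_dataset("monology/pile-uncopyrighted", split="train", streaming=True)

def batches(dataset, batch_size=8, seq_len=1024):
    texts = []
    for example in dataset:
        texts.append(example["text"])
        if len(texts) == batch_size:
            enc = tokenizer(texts, truncation=True, max_length=seq_len,
                            padding="max_length", return_tensors="pt")
            yield {k: v.to(model.device) for k, v in enc.items()}
            texts = []

dataloader = batches(raw)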
2. Streaming Activation Dataset
Once shards are on disk, wrap them in a PyTorch IterableDataset with a ring-buffer shuffle. This keeps memory usage flat regardless of dataset size.
import random
import torch
from torch.utils.data import IterableDataset
from safetensors.torch import load_file

class StreamingActivationDataset(IterableDataset):
    def __init__(self, shard_paths, batch_size=512, buffer_size=8192):
        self.shard_paths = shard_paths
        self.batch_size = batch_size
        self.buffer_size = buffer_size

    def __iter__(self):
        buffer = []
        for path in self.shard_paths:
            activations = load_file(path)["activations"].float()
            buffer.extend(activations.unbind(0))
            while len(buffer) >= self.buffer_size:
                random.shuffle(buffer)
                n_full = (self.buffer_size // self.batch_size) * self.batch_size
                for i in range(0, n_full, self.batch_size):
                    yield torch.stack(buffer[i:i + self.batch_size])
                buffer = buffer[n_full:]
        # Flush remaining samples once all shards are exhausted
        random.shuffle(buffer)
        for i in range(0, len(buffer) - self.batch_size + 1, self.batch_size):
            yield torch.stack(buffer[i:i + self.batch_size])

Set buffer_size to 5,000-10,000 samples. Larger buffers improve training convergence but cost more RAM on the host.
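Wiring it into a training loop, assuming the shards written by the capture step above:

from pathlib import Path
from torch.utils.data import DataLoader

shard_paths = sorted(Path("activations/layer_32").glob("shard_*.safetensors"))
dataset = StreamingActivationDataset(shard_paths, batch_size=512, buffer_size=8192)

# batch_size=None because the dataset already yields ready-made [512, d_model] batches
loader = DataLoader(dataset, batch_size=None)
for activation_batch in loader:
    pass  # feed each [512, 8192] float32 batch to the SAE training step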
3. sae-lens Configuration
Install sae-lens with pip install sae-lens. Create a training config:
# sae_config.yaml
model_name: meta-llama/Llama-3.1-70B
hook_name: hook_resid_post
hook_layer: 32
architecture: topk
d_in: 8192 # d_model for Llama 70B
expansion_factor: 32 # d_sae = 32 * 8192 = 262,144 features
k: 64 # 64 features active per token
lr: 5.0e-5
lr_scheduler_name: cosine
lr_warm_up_steps: 1000
l1_coefficient: 0.0 # TopK uses hard sparsity, no L1 needed
train_batch_size_tokens: 4096
training_tokens: 100_000_000
normalize_activations: "expected_average_only_in"
checkpoint_path: ./checkpoints/sae_layer32
checkpoint_every_n_training_steps: 500

The normalize_activations setting is critical for training stability. TopK SAEs benefit from normalizing by the expected average activation norm computed over a calibration set before training starts. sae-lens handles this automatically with this setting.
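For intuition about what that setting does: the common recipe is to rescale inputs so their average L2 norm equals sqrt(d_in), estimated over a calibration sample before training. A rough illustrative sketch of that recipe (an assumption about the behavior, not the sae-lens API):

import math
import torch

def estimate_norm_scale(calibration_batches, d_in=8192):
    # Scale factor so the average activation norm becomes sqrt(d_in)
    mean_norm = torch.stack(
        [batch.float().norm(dim=-1).mean() for batch in calibration_batches]
    ).mean()
    return math.sqrt(d_in) / mean_norm.item()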
4. Training Launch
Single GPU:
python -m sae_lens.training.train_sae \
--config sae_config.yaml \
--activation_store_dir ./activations/layer_32

Multi-GPU with FSDP across 4 H100 SXM5 GPUs on a single node (see distributed LLM training with FSDP on GPU cloud for node setup):
torchrun --nproc_per_node=4 \
--master_addr=${MASTER_ADDR} \
--master_port=29500 \
-m sae_lens.training.train_sae_fsdp \
--config sae_config.yaml \
--activation_store_dir ./activations/layer_32 \
--fsdp_sharding_strategy FULL_SHARD

FSDP shards the SAE weight tensors across GPUs. With expansion_factor = 32 and d_model = 8192, the SAE has ~4.3B parameters at ~8.6GB in float16. Across 4 H100s, each GPU holds ~2.1GB of SAE state, leaving most VRAM available for activation batches and optimizer states.
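The parameter arithmetic behind those numbers, as a quick sanity check:

d_model = 8192
expansion_factor = 32
d_sae = expansion_factor * d_model            # 262,144 features

# Encoder and decoder matrices dominate; biases are negligible
params = 2 * d_model * d_sae + d_sae + d_model
fp16_gb = params * 2 / 1e9                    # ~8.6 GB in float16
per_gpu_gb = fp16_gb / 4                      # ~2.1 GB per GPU with 4-way FULL_SHARD
print(f"{params/1e9:.2f}B params, {fp16_gb:.1f}GB fp16, {per_gpu_gb:.2f}GB per GPU")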
5. Checkpoint Strategy
The training loop saves checkpoints every 500 steps by default. On spot preemption, relaunch with:
torchrun --nproc_per_node=4 ... \
--resume_from_checkpoint ./checkpoints/sae_layer32/step_5000

sae-lens checkpoints include optimizer state, so training resumes exactly from where it stopped. This makes spot instances viable for the SAE training phase. The activation capture phase is not checkpoint-friendly; run it on on-demand instances.
Evaluation
Three metrics determine whether a trained SAE is useful or needs to be retrained.
Reconstruction Loss
Normalized L2 reconstruction loss measures how well the SAE recovers the original activation from its sparse code:
normalized_L2 = ||x - SAE(x)||_2 / ||x||_2

Target: below 0.05 for a well-trained TopK SAE. Values above 0.10 mean the SAE is losing meaningful information, and features derived from it will be unreliable for steering.
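The same metric in code, as a minimal helper (evaluation tooling typically logs an equivalent value during training):

import torch

def normalized_l2(x: torch.Tensor, x_hat: torch.Tensor) -> torch.Tensor:
    # ||x - SAE(x)||_2 / ||x||_2, averaged over the batch
    return (torch.norm(x - x_hat, dim=-1) / torch.norm(x, dim=-1)).mean()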
Feature Density
The feature density histogram plots the fraction of tokens each feature activates on. A feature activating on more than 30% of tokens is polysemantic (it represents multiple concepts at once). A feature activating on fewer than 1 in 100,000 tokens (density below 1e-5) is dead (it effectively never fires).
| Architecture | Typical Dead Feature Rate at k=64 |
|---|---|
| TopK | 2-4% |
| JumpReLU | 5-12% |
| Gated | 3-6% (harder to control) |
| Matryoshka | 1-3% (fine resolution level) |
A dead feature rate above 5% for TopK SAEs is a signal to check your learning rate schedule or activation normalization. JumpReLU SAEs typically run at higher dead feature rates than TopK at matched k; treat 8-12% as acceptable for JumpReLU rather than applying the 5% threshold universally.
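A rough way to compute the density histogram and dead-feature fraction over a held-out activation sample, assuming the SAE's forward pass returns (reconstruction, feature activations) as in the TopK sketch earlier:

import torch

@torch.no_grad()
def feature_density(sae, activation_batches, d_sae):
    # Count how often each feature fires across the sample
    fire_counts = torch.zeros(d_sae)
    n_tokens = 0
    for x in activation_batches:
        _, acts = sae(x)
        fire_counts += (acts > 0).float().sum(dim=0).cpu()
        n_tokens += x.shape[0]
    density = fire_counts / n_tokens
    dead_fraction = (density < 1e-5).float().mean().item()
    polysemantic_fraction = (density > 0.30).float().mean().item()
    return density, dead_fraction, polysemantic_fraction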
Automated Interpretation
For each feature, collect the 50 token sequences where that feature activated most strongly, then pass them to a lightweight LLM judge:
JUDGE_PROMPT = """
Below are 50 short text sequences. Each one caused feature {feature_id} to activate strongly
in a language model's residual stream.
Sequences:
{sequences}
Describe in exactly one sentence what concept or pattern this feature detects.
If the sequences share no clear pattern, respond with: UNCLEAR
"""

Track the fraction of features where the judge produces a clear (non-UNCLEAR) description. A well-trained SAE on a 70B model should reach 70-85% interpretable features. Below 60% suggests the SAE is undertrained (more training tokens needed) or the expansion factor is too small.
Activation Steering at Inference Time
Once the SAE is trained, each feature corresponds to a specific direction in the model's residual stream. That direction is the feature's row of the SAE decoder weight matrix W_dec (stored here with shape [d_sae, d_model]):
import numpy as np
import torch
from safetensors.torch import load_file
# Load trained SAE
sae_weights = load_file("./checkpoints/sae_layer32/final/model.safetensors")
W_dec = sae_weights["W_dec"] # shape: [d_sae, d_model]
# Feature 1234 direction in residual stream space
feature_id = 1234
feature_direction = W_dec[feature_id].float().numpy() # shape: [d_model]
np.save(f"./steering_vectors/feature_{feature_id}.npy", feature_direction)

vLLM Forward Hook Integration
Register a hook on the target layer via vLLM's model internals:
import numpy as np
import torch
from vllm import LLM, SamplingParams
feature_direction = torch.tensor(
np.load("./steering_vectors/feature_1234.npy"), dtype=torch.bfloat16
).cuda()
steering_coefficient = 30.0 # Start here; adjust by sweeping [5, 100]
def steering_hook(module, input, output):
    # Decoder layers may return a tuple (hidden_states, residual); handle both cases
    if isinstance(output, tuple):
        hidden_states = output[0] + steering_coefficient * feature_direction
        return (hidden_states,) + output[1:]
    else:
        return output + steering_coefficient * feature_direction
llm = LLM(model="meta-llama/Llama-3.1-70B", dtype="bfloat16")
# Register hook on layer 32
target_layer = llm.llm_engine.driver_worker.model_runner.model.model.layers[32]
hook = target_layer.register_forward_hook(steering_hook)
# Generate with steering active
outputs = llm.generate(
["Describe the meeting agenda"], SamplingParams(max_tokens=200)
)
hook.remove()

A worked example: feature 1234 corresponds to "formal writing register" (identified by the LLM judge). At steering_coefficient = 30, the model's output shifts from conversational to formal without any system prompt changes. At steering_coefficient = 100, outputs become incoherent as the intervention overrides too many model computations.
Start at coefficient 30 and evaluate on a 50-prompt eval set. Move to 50 if the effect is too weak, or 15 if outputs are degraded. The sweet spot is usually between 20 and 60 for semantic steering at layer 32 in a 70B model.
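A minimal sweep using the hook defined above; the eval prompts here are placeholders, and judging "too weak" versus "degraded" still needs a human or LLM grader:

eval_prompts = ["Describe the meeting agenda", "Summarize the quarterly report"]  # placeholder eval set

results = {}
for coeff in [15, 30, 50]:
    steering_coefficient = coeff  # rebinds the global read by steering_hook
    hook = target_layer.register_forward_hook(steering_hook)
    outputs = llm.generate(eval_prompts, SamplingParams(max_tokens=200))
    hook.remove()
    results[coeff] = [o.outputs[0].text for o in outputs]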
For production use, integrate with a vLLM production deployment setup behind the standard OpenAI-compatible endpoint. The hook is registered per-request or globally, depending on whether you want steering to apply to all requests or only specific sessions.
EU AI Act Compliance
SAE training gives safety and compliance teams a concrete technical artifact for AI Act obligations, not a theoretical argument.
Article 53 of the EU AI Act requires general-purpose AI providers to publish "sufficiently detailed summaries" of training data and to document model capabilities. A trained SAE lets you map specific input patterns to specific model features and behaviors. This is a stronger interpretability artifact than attention visualization or logit attribution, which describe correlations rather than causal mechanisms inside the model. When regulators ask "what does the model do when it sees this input," an SAE feature map lets you point at specific internal states rather than describing emergent behavior.
Article 12 addresses record-keeping for high-risk AI systems: automatic logging of events over the system's lifetime. SAE feature activations can be logged per-request as structured evidence of model behavior: which features activated, at what intensity, and how they map to known semantic categories. This is exactly the kind of machine-readable audit trail that Article 12 envisions.
The data sovereignty constraint is where self-hosted training becomes legally necessary, not just convenient. Many enterprise DPAs explicitly prohibit transmitting model activations (which are derivative of customer data) off-cluster to a third-party processor. Training the SAE on the same cluster where the base model lives, using data that stays within the infrastructure boundary, is the only path that fits under these agreements. EU-based teams using data center partners with EU-resident compute satisfy the geographic residency requirement that some DPAs now specify explicitly.
For the full compliance infrastructure picture, see our EU AI Act compliance guide for GPU cloud and the confidential GPU computing guide for regulated workloads.
Safety and interpretability teams that need to keep activations on-cluster for compliance reasons can run full SAE training pipelines on Spheron's H200 and B200 instances without shipping raw model internals to a third-party provider.
Quick Setup Guide
From the Spheron dashboard, launch one H200 or B200 instance for activation capture (needs the base model in full precision) and one or two H100 SXM5 instances for the SAE training loop. Tag both with the same project label for billing visibility.
Load the base model with Hugging Face Transformers and register a forward hook on the target layer's residual stream output. Stream a dataset (e.g., The Pile or RedPajama) through the model in inference mode, collecting hook outputs. Write shards to NVMe as float16 tensors using safetensors. For a 70B model at layer 32, plan for ~1.5 hours at batch size 8 on a single H200.
Use a PyTorch IterableDataset that reads shards from disk and yields activation batches. This avoids loading all activations into VRAM at once. Set buffer size to 5,000-10,000 samples and enable shuffle with a ring buffer. The streaming approach allows SAE training on datasets larger than VRAM.
Install sae-lens (pip install sae-lens). Configure the SAEConfig: set architecture to 'topk', expansion_factor to 32 (try 64 for finer-grained features), k (active features per token) to 32-64, and lr to 5e-5 with cosine warmup. Pass your streaming dataset as the activation source. For multi-GPU with PyTorch FSDP, wrap the SAE model with FullyShardedDataParallel and shard across 2-4 H100 SXM5 GPUs.
Track three metrics throughout training: (1) L2 reconstruction loss normalized by activation norm (target < 0.05), (2) feature density histogram (fraction of tokens each feature activates on - dead features have density < 1e-5 and should be under 5% of total), (3) mean number of active features per token (should match your k parameter). Use Weights & Biases or MLflow for tracking. Stop training when reconstruction loss plateaus and dead feature fraction stabilizes.
For each SAE feature, collect the top-50 activating token contexts. Pass them to a lightweight LLM (Llama 3.1 8B or similar) with the prompt: 'Given these token sequences that all cause feature N to activate strongly, describe in one sentence what concept or pattern this feature detects.' Store descriptions in a feature index. Features with unclear descriptions (the LLM says 'unclear' or gives contradictory answers) are candidates for dead feature analysis or architecture tuning.
Export the feature direction vector (SAE decoder weight column for the target feature, shape [d_model]) as a .npy file. In vLLM, register a custom model forward hook via the model's named_modules() API at the target layer. In the hook, add: residual_stream += steering_coefficient * feature_direction. Start with steering_coefficient in [10, 50] and reduce if the model becomes incoherent. Serve the steered model behind the standard OpenAI-compatible vLLM endpoint.
Frequently Asked Questions
A sparse autoencoder (SAE) is a one-hidden-layer network trained to reconstruct model activations through a sparse bottleneck. Public SAEs from Anthropic or EleutherAI are trained on base models like GPT-2 or Claude. If you are studying a fine-tuned, domain-adapted, or proprietary model, those SAEs do not transfer. Training your own on your target model's activations is the only way to get interpretable features for that specific model.
SAE training is GPU-memory-bound, not compute-bound. For a 70B base model generating activations at layer 32 (d_model 8192), each activation vector is 8192 * 4 bytes = 32KB per token in float32. A typical dataset of 10M tokens therefore requires ~320GB for the full activation buffer in float32 (or ~160GB in float16). With the streaming setup described in this guide, that buffer lives in NVMe shards rather than in VRAM. The SAE itself at a 32x expansion factor has ~4.3B parameters (~8.6GB in float16). You need the base model in memory to capture activations, so a single H200 (141GB) comfortably handles base models up to ~34B and fits a 70B in bfloat16 with almost no headroom; for larger models or more headroom, use a B200 (192GB) or multi-GPU FSDP.
TopK SAEs (introduced by OpenAI's Gao et al. 2024, arXiv:2406.04093) produce the most interpretable features and are easiest to evaluate. JumpReLU (DeepMind) trades some interpretability for better reconstruction fidelity on code and math tokens. Gated SAEs (Google DeepMind, Rajamanoharan et al. 2024) are harder to train but show lower feature polysemanticity. Matryoshka SAEs (2025) allow hierarchical resolution and are best if you want to study both coarse and fine-grained features from the same training run. For most interpretability teams starting out, TopK is the right default.
After training, each SAE feature corresponds to a direction in residual stream space. To steer the model, you identify the feature direction (a 1D vector in d_model space) for a target concept, scale it by a coefficient, and add it to the residual stream at the chosen layer during inference. With vLLM, this is implemented via a custom forward hook on the transformer block. The result is deterministic intervention on model behavior without prompt engineering or fine-tuning.
SAEs can serve as technical evidence for Article 53 (GPAI transparency obligations) and Article 12 (record-keeping and logging). A trained SAE lets you point at specific features your model activates on inputs, which is stronger than attention visualization or logit attribution for explaining model behavior to regulators. They are not a substitute for full compliance documentation, but they are a defensible interpretability artifact.
SAE training runs are long (12-48 hours for large models) but are checkpoint-friendly. Spot instances at 40-60% cost reduction are viable if you checkpoint every 500-1000 steps and configure automatic resume. Activation capture against the base model is not checkpoint-friendly and should run on on-demand instances to avoid mid-capture interruption corrupting the dataset.
