NVIDIA Transformer Engine is the library that makes FP8 training practical on H100 and H200 hardware. It handles the three things that raw PyTorch FP8 casting does not: per-tensor dynamic scaling, amax history tracking, and efficient FP8 GEMMs via cuBLAS. If you are running training or fine-tuning jobs on H100 SXM5 instances on Spheron and not yet using Transformer Engine, you are likely leaving 1.3-1.7x throughput on the table. This post covers the full setup: installation, PyTorch and JAX integration, the recipe API, training benchmarks, and the most common pitfalls teams run into in production.
What NVIDIA Transformer Engine Actually Does
FP8 training is not just calling model = model.to(torch.float8_e4m3fn). That approach fails in practice because FP8 has a narrow dynamic range: values outside that range overflow (to NaN or infinity, depending on the format) or underflow to zero. Training gradients span many orders of magnitude, so naive casting causes NaN losses within a few hundred steps.
Transformer Engine solves this with three components working together:
- Per-tensor dynamic scaling. Before each GEMM, TE computes a scale factor based on the tensor's historical maximum absolute value. The scale maps the tensor's actual range into the FP8 representable range. Without this, outlier activations corrupt entire layers.
- FP8 GEMMs via cuBLAS. The H100's fourth-generation Tensor Cores have dedicated FP8 matrix multiply units. TE calls these directly via cuBLAS, bypassing the overhead of PyTorch's generic dispatch path.
- Delayed scaling to amortize overhead. Computing a max-reduction over the full tensor before every GEMM would be expensive. TE's default DelayedScaling recipe maintains a rolling history of amax values and computes the scale from the previous iteration's statistics, adding almost zero per-step overhead.
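The delayed-scaling bookkeeping described in the list above can be sketched in plain Python. This is a toy model for a single tensor, not TE's implementation: the DelayedScaler class and its step method are hypothetical names, the initial scale of 1.0 is an assumption, and details such as TE's scaling margin are left out.

```python
from collections import deque

E4M3_MAX = 448.0  # largest finite value representable in FP8 E4M3


class DelayedScaler:
    """Toy model of delayed scaling for one tensor (hypothetical, not a TE class)."""

    def __init__(self, history_len=16, fp8_max=E4M3_MAX):
        self.history = deque(maxlen=history_len)  # rolling window of amax values
        self.fp8_max = fp8_max
        self.scale = 1.0  # assumed starting scale before any amax is recorded

    def step(self, values):
        """Return the scale used for this step, then record this step's amax."""
        scale_used = self.scale  # computed from earlier steps only
        self.history.append(max(abs(v) for v in values))
        # "max" algorithm: the next scale maps the window's largest amax to fp8_max.
        window_amax = max(self.history)
        if window_amax > 0.0:
            self.scale = self.fp8_max / window_amax
        return scale_used


scaler = DelayedScaler(history_len=4)
used = [scaler.step(t) for t in ([0.5, -2.0], [1.0, -1.5], [8.0, 0.25])]
print(used)          # scales applied at steps 0, 1, 2
print(scaler.scale)  # scale that step 3 would use
```

Note that the scale applied at each step never depends on that step's own values. That is what makes the scheme cheap, and it is also why a sudden outlier only shows up in the scale one step later.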
| Format | Bits | Bytes/value | GPU support | Typical use |
|---|---|---|---|---|
| FP32 | 32 | 4 | All CUDA | Optimizer states, master weights |
| BF16 | 16 | 2 | Ampere+ | Default training precision |
| FP8 E4M3 | 8 | 1 | Hopper, Blackwell | Forward pass weights and activations |
| FP8 E5M2 | 8 | 1 | Hopper, Blackwell | Backward pass gradients |
The E4M3 format (4 exponent bits, 3 mantissa bits) gives higher precision at the cost of narrower dynamic range, making it right for weights and activations. E5M2 (5 exponent bits, 2 mantissa bits) covers a wider range with less precision, which suits gradients better.
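The precision/range trade-off can be derived from the bit layouts. The sketch below assumes the commonly used encodings: E4M3 in its "fn" variant (exponent bias 7, no infinities, a single NaN pattern) and an IEEE-style E5M2 (exponent bias 15). The helper name fp8_limits is illustrative and not part of any library.

```python
def fp8_limits(exp_bits, man_bits, bias, ieee_like):
    """Return (largest finite value, smallest positive subnormal) for an FP8 layout."""
    top_exp_field = (1 << exp_bits) - 1
    top_mantissa = (1 << man_bits) - 1
    if ieee_like:
        # E5M2: the all-ones exponent field is reserved for inf/NaN.
        max_exp_field, max_mantissa = top_exp_field - 1, top_mantissa
    else:
        # E4M3 "fn": only exponent=all-ones with mantissa=all-ones is NaN.
        max_exp_field, max_mantissa = top_exp_field, top_mantissa - 1
    largest = (1 + max_mantissa / (1 << man_bits)) * 2.0 ** (max_exp_field - bias)
    smallest = 2.0 ** (1 - bias - man_bits)
    return largest, smallest


E4M3_MAX, E4M3_MIN = fp8_limits(4, 3, bias=7, ieee_like=False)
E5M2_MAX, E5M2_MIN = fp8_limits(5, 2, bias=15, ieee_like=True)
print(E4M3_MAX, E4M3_MIN)  # 448.0 0.001953125
print(E5M2_MAX, E5M2_MIN)  # 57344.0 1.52587890625e-05
```

E5M2 reaches values 128x larger than E4M3 but has only four mantissa steps per power of two instead of eight, which matches the forward/backward split described above.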
For the full H100 hardware context, see the H100's FP8 Tensor Core specs covering the fourth-gen Tensor Core architecture and all precision tiers.
FP8 on H100 vs H200 vs B200: What Changes
The TE recipe API is identical across Hopper and Blackwell. You write the same te.fp8_autocast context manager, the same DelayedScaling recipe, the same layer replacements. Only the underlying GEMM throughput changes.
| GPU | Architecture | HBM capacity | FP8 TFLOPS (dense) | Memory bandwidth | TE use case |
|---|---|---|---|---|---|
| H100 SXM5 | Hopper | 80 GB HBM3 | 1,979 TFLOPS | 3.35 TB/s | Training up to ~30B, FP8 inference serving |
| H200 SXM5 | Hopper | 141 GB HBM3e | 1,979 TFLOPS | 4.8 TB/s | Training 70B+, memory-bandwidth-bound inference |
| B200 SXM6 | Blackwell | 192 GB HBM3e | 4,500 TFLOPS | 8.0 TB/s | Large-scale training, FP4 capable |
The H200's main advantage for TE workloads is memory capacity and bandwidth. At 141 GB, you can fit a 70B model in FP8 on a single GPU with room for KV cache. The H200 SXM5 memory bandwidth advantage also reduces memory-bound bottlenecks for training runs that are HBM-limited rather than compute-limited.
For teams targeting Blackwell, FP4 on Blackwell is the next step beyond FP8, with the B200 delivering roughly 2x the throughput of FP8 at the cost of more quantization error.
Installation
The fastest path is the NVIDIA NGC container, which ships with Transformer Engine pre-installed and CUDA extensions pre-compiled:
docker pull nvcr.io/nvidia/pytorch:24.09-py3
docker run --gpus all -it nvcr.io/nvidia/pytorch:24.09-py3
For a pip install into an existing environment, CUDA 12.1 or later is required:
pip install transformer-engine[pytorch]
For JAX support:
pip install transformer-engine[jax]
The pip install compiles CUDA extensions during installation. This takes 5-15 minutes depending on hardware. If CUDA dev headers are not available in your container, the compile step will fail. Use NGC containers to avoid this entirely.
Verify the install:
import transformer_engine
print(transformer_engine.__version__)
Code examples in this post are written against Transformer Engine 1.x (mid-2026). The TE 0.x API used different recipe class names. Confirm your version before copy-pasting examples.
PyTorch Integration: te.Linear, te.LayerNorm, te.TransformerLayer
The TE PyTorch API is a drop-in replacement for standard PyTorch modules. Same constructor arguments, same forward signature.
Custom TransformerBlock with TE layers:
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format
class TransformerBlock(torch.nn.Module):
    def __init__(self, hidden_size, ffn_hidden_size, num_heads):
        super().__init__()
        self.ln1 = te.LayerNorm(hidden_size)
        # attn_qkv omitted; QKV projection and attention scores are simplified away below
        self.attn_proj = te.Linear(hidden_size, hidden_size, bias=True)
        self.ln2 = te.LayerNorm(hidden_size)
        self.ffn1 = te.Linear(hidden_size, ffn_hidden_size, bias=True)
        self.ffn2 = te.Linear(ffn_hidden_size, hidden_size, bias=True)
        self.act = torch.nn.GELU()
    def forward(self, x):
        # Attention (simplified stand-in: QKV projection and attention scores omitted)
        residual = x
        x = self.ln1(x)
        x = self.attn_proj(x)  # real code would compute Q*K^T*V here first
        x = x + residual
        # FFN
        residual = x
        x = self.ln2(x)
        x = self.act(self.ffn1(x))
        x = self.ffn2(x) + residual
        return x
DelayedScaling recipe construction:
from transformer_engine.common.recipe import DelayedScaling, Format
fp8_recipe = DelayedScaling(
    fp8_format=Format.HYBRID,  # E4M3 forward, E5M2 backward
    amax_history_len=16,  # rolling window of amax values
    amax_compute_algo="max",  # use max over history window
)
Format.HYBRID is the default: E4M3 for forward-pass tensors (weights, activations), E5M2 for backward-pass gradients. Use Format.E4M3 if you want E4M3 everywhere (higher precision, smaller dynamic range for gradients).
Full training step with FP8 context:
model = TransformerBlock(hidden_size=4096, ffn_hidden_size=16384, num_heads=32).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
# dataloader and criterion are assumed to be defined elsewhere
for step, batch in enumerate(dataloader):
    optimizer.zero_grad()
    # Only the forward pass needs to run inside the FP8 context
    with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
        output = model(batch["input_ids"].cuda())
    loss = criterion(output, batch["labels"].cuda())
    loss.backward()
    optimizer.step()
Everything inside the fp8_autocast context uses FP8 GEMMs where possible. PyTorch autograd handles the backward pass through TE's custom CUDA kernels. Master weights and optimizer states stay in higher precision (BF16/FP32); TE does not touch them.
High-level alternative: te.TransformerLayer wraps the full attention + FFN block in one module. Use it if you do not need per-layer customization:
layer = te.TransformerLayer(
    hidden_size=4096,
    ffn_hidden_size=16384,
    num_attention_heads=32,
)
JAX/Flax Integration
Install with JAX support:
pip install transformer-engine[jax]
TE provides Flax module replacements under transformer_engine.jax.flax:
from flax import linen as nn
import transformer_engine.jax.flax as te_flax
import transformer_engine.jax as te_jax
from transformer_engine.common.recipe import DelayedScaling
recipe = DelayedScaling()
# Replace flax.linen.Dense with te_flax.DenseGeneral
class TELayer(nn.Module):
    features: int
    @nn.compact
    def __call__(self, x):
        with te_jax.fp8_autocast(enabled=True, fp8_recipe=recipe):
            return te_flax.DenseGeneral(features=self.features)(x)
FP8 support in JAX as of TE 1.x is primarily through DelayedScaling. Float8CurrentScaling support in JAX is more limited than in PyTorch. If you need stability-sensitive fine-tuning with JAX, plan for extra loss monitoring in early training steps.
Mixed-Precision Recipe: FP8 Forward + BF16 Weights + Gradient Accumulation
The standard TE mixed-precision setup keeps master weights in BF16 and optimizer states in FP32, and uses FP8 only for the compute-heavy GEMM operations:
| Tensor | Format | Why |
|---|---|---|
| Forward activations | FP8 E4M3 | High-throughput GEMM input, precision adequate |
| Weight matrix (compute) | FP8 E4M3 | Loaded fresh each forward pass |
| Weight master copy | BF16 | Accumulated by optimizer, needs precision |
| Gradients (backward GEMM) | FP8 E5M2 | Wide dynamic range for gradient values |
| Gradient accumulation buffer | FP32 | Prevents precision loss across micro-steps |
| Optimizer states (Adam m, v) | FP32 | Needed for stable weight updates |
This setup gives you FP8 GEMM throughput while keeping the optimizer numerically stable. The VRAM split: FP8 weights for compute take 1 byte/param, BF16 master weights take 2 bytes/param, and FP32 optimizer states take 8 bytes/param (two FP32 Adam moments). For a 7B model, that is roughly 7 + 14 + 56 = 77 GB before gradients and activations. That leaves almost no headroom on an H100's 80 GB, so single-GPU runs depend on careful activation management and often on sharding or offloading part of the training state.
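The 7B arithmetic can be reproduced with a few lines of Python. This is a back-of-the-envelope sketch using the bytes-per-parameter figures from the text and 1 GB = 1e9 bytes; the function name training_state_gb is illustrative. It also shows what the FP32 gradient accumulation buffer from the table would add on top, as a separate assumption.

```python
# Bytes per parameter for each persistent tensor, taken from the text above.
BYTES_PER_PARAM = {
    "fp8_compute_weights": 1,
    "bf16_master_weights": 2,
    "fp32_adam_states": 8,  # two FP32 moments (m and v), 4 bytes each
}


def training_state_gb(num_params, bytes_per_param=BYTES_PER_PARAM):
    """Per-component memory in GB (1 GB = 1e9 bytes), excluding gradients and activations."""
    return {name: num_params * size / 1e9 for name, size in bytes_per_param.items()}


breakdown = training_state_gb(7e9)
total_gb = sum(breakdown.values())
# Assumption: an FP32 gradient accumulation buffer adds 4 bytes per parameter.
total_with_fp32_grads_gb = total_gb + 7e9 * 4 / 1e9

print(breakdown)
print(total_gb)                  # 77.0
print(total_with_fp32_grads_gb)  # 105.0
```

Swapping in a different parameter count or dropping a component gives a quick first check on whether a configuration can fit on one GPU before activations are considered.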
Float8CurrentScaling is the alternative recipe. Instead of delayed scaling, it computes the scale factor from the current tensor's actual max value before each GEMM. This eliminates the stale-scale instability risk but adds a max-reduction operation per tensor per step. For fine-tuning runs with aggressive learning rate schedules or small batch sizes, the stability improvement often outweighs the overhead:
from transformer_engine.common.recipe import Float8CurrentScaling
stable_recipe = Float8CurrentScaling()
Use DelayedScaling for pre-training (stable loss landscape, large batches). Switch to Float8CurrentScaling if you see loss spikes in the first 100-200 steps of a fine-tuning run.
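To see why a scale taken from the current tensor sidesteps the stale-scale problem, here is a small plain-Python illustration. It does not call TE: the tensors are made-up lists, and amax and exceeds_fp8_range are hypothetical helpers that only check whether scaled values stay inside the E4M3 range.

```python
E4M3_MAX = 448.0  # largest finite value representable in FP8 E4M3


def amax(values):
    """Maximum absolute value of a tensor (here: a plain list)."""
    return max(abs(v) for v in values)


def exceeds_fp8_range(values, scale, fp8_max=E4M3_MAX):
    """True if any scaled value falls outside the FP8 representable range."""
    return any(abs(v) * scale > fp8_max for v in values)


previous = [0.5, -2.0, 1.0]   # tensor seen on the previous step
current = [0.5, -32.0, 1.0]   # same tensor after a sudden outlier

delayed_scale = E4M3_MAX / amax(previous)  # what a history-based scale would use
current_scale = E4M3_MAX / amax(current)   # what a current-tensor scale would use

print(exceeds_fp8_range(current, delayed_scale))  # True: stale scale overflows
print(exceeds_fp8_range(current, current_scale))  # False: outlier maps to fp8_max
```

The price of the second approach is the extra amax reduction over the current tensor, which is the per-step overhead mentioned above.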
Benchmarks: BF16 vs FP8 Training Throughput on H100 SXM5
These are directional throughput estimates for dense transformer training on H100 SXM5 single-node. Real numbers vary with sequence length, batch size, and framework overhead.
| Model size | BF16 tokens/sec | FP8 tokens/sec | Speedup | VRAM delta |
|---|---|---|---|---|
| 7B (single GPU) | ~95,000 | ~140,000 | ~1.47x | -18 GB |
| 13B (single GPU) | ~48,000 | ~70,000 | ~1.45x | -32 GB |
| 70B (8-GPU NVLink) | ~18,000 | ~29,000 | ~1.61x | -140 GB total |
Larger models see higher speedups because they are more memory-bandwidth-bound. For a 70B model in BF16, each step loads 140 GB of weights across the NVLink fabric. FP8 cuts that to 70 GB, reducing bandwidth pressure significantly. Smaller models at small batch sizes run GEMMs too small to saturate the Tensor Cores, so fixed costs such as casting and kernel launches take a larger share of each step and the FP8 FLOP-rate advantage does not fully materialize.
The VRAM delta is the reduction in activation and weight memory from switching to FP8. Master weights and optimizer states remain in BF16/FP32, so total VRAM does not drop by the full 50% weight-size reduction.
Benchmarks: FP8 Inference vs BF16 on H100 SXM5
For inference, these numbers represent throughput at batch size 1 with a warm cache. FP8 rows run on a single H100; BF16 rows need two.
| Model | Precision | Tokens/sec (H100 SXM5) | VRAM used | Cost/1M tokens (on-demand) |
|---|---|---|---|---|
| Llama 3.3 70B | BF16 | ~110 | ~140 GB (2x H100) | ~$19.70/M |
| Llama 3.3 70B | FP8 | ~190 | ~73 GB (1x H100) | ~$5.70/M |
| Qwen3 72B | BF16 | ~105 | ~144 GB (2x H100) | ~$20.63/M |
| Qwen3 72B | FP8 | ~180 | ~75 GB (1x H100) | ~$6.02/M |
Cost/1M tokens is calculated as: (total $/hr across all GPUs) / (tokens/sec × 3,600) × 1,000,000, using H100 SXM5 on-demand pricing of $3.90/hr/GPU. FP8 rows fit in 1x H100 ($3.90/hr); BF16 rows require 2x H100 ($7.80/hr total).
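The same calculation as a small Python helper, in case you want to plug in your own throughput or pricing. The function name cost_per_million_tokens is illustrative; the inputs are the Llama 3.3 70B rows and the $3.90/hr/GPU price from the table above.

```python
def cost_per_million_tokens(usd_per_gpu_hour, num_gpus, tokens_per_sec):
    """Serving cost in USD per one million generated tokens."""
    total_usd_per_hour = usd_per_gpu_hour * num_gpus
    tokens_per_hour = tokens_per_sec * 3600
    return total_usd_per_hour / tokens_per_hour * 1_000_000


# Llama 3.3 70B rows from the table: BF16 on 2x H100, FP8 on 1x H100.
bf16_cost = cost_per_million_tokens(3.90, num_gpus=2, tokens_per_sec=110)
fp8_cost = cost_per_million_tokens(3.90, num_gpus=1, tokens_per_sec=190)

print(f"BF16: ${bf16_cost:.2f}/M tokens")  # BF16: $19.70/M tokens
print(f"FP8:  ${fp8_cost:.2f}/M tokens")   # FP8:  $5.70/M tokens
```

Most of the saving comes from halving the GPU count, with the higher tokens/sec contributing the rest.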
Pricing fluctuates based on GPU availability. The prices above are based on 22 May 2026 and may have changed. Check current GPU pricing for live rates.
For inference at scale, use vLLM or TRT-LLM directly rather than raw TE ops. Both add batching, KV cache management, and production serving infrastructure that raw TE does not provide. See vLLM FP8 inference deployment for the full production setup with Docker, tensor parallelism, and monitoring.
Common Pitfalls
Delayed scaling instability in early steps. The first 50-200 training steps use scale factors initialized from scratch. Outlier gradients in early steps can overflow the FP8 range, causing NaN loss before the amax history stabilizes. If your loss goes NaN in the first 100 steps, switch to Float8CurrentScaling or warm up with BF16 for 100 steps before enabling FP8.
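One way to implement the BF16 warm-up is to gate the enabled flag of te.fp8_autocast on the step counter. The helper below is a minimal sketch: fp8_enabled and bf16_warmup_steps are hypothetical names, and the 100-step default simply mirrors the figure in the paragraph above.

```python
def fp8_enabled(step, bf16_warmup_steps=100):
    """Return True once the BF16 warm-up phase is over (steps are 0-indexed)."""
    return step >= bf16_warmup_steps


# Steps 0..99 run in BF16; FP8 switches on at step 100.
schedule = [fp8_enabled(step) for step in (0, 99, 100, 500)]
print(schedule)  # [False, False, True, True]
```

In the training loop, the result would be passed as te.fp8_autocast(enabled=fp8_enabled(step), fp8_recipe=fp8_recipe), so the same model code runs in both phases.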
Recipe mismatch between training and checkpoint. If you train with DelayedScaling and then run inference with a different recipe configuration, the scale statistics embedded in the checkpoint may not match. Always export checkpoints with the amax history cleared, or document the recipe used for each checkpoint.
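NVTE_FLASH_ATTN disabled or overridden. Transformer Engine selects its attention backend based on environment variables such as NVTE_FLASH_ATTN. Recent releases treat the flash attention path as enabled when the variable is unset, but an inherited NVTE_FLASH_ATTN=0 (for example from a debugging session or a base image) pushes TE to an unfused attention implementation that is significantly slower. Add export NVTE_FLASH_ATTN=1 to your training script or Docker entrypoint to remove the ambiguity.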
FP8 checkpoint export for TRT-LLM. If you plan to use a TE-trained checkpoint with TensorRT-LLM for production inference, the checkpoint must be exported in a format TRT-LLM understands. TE provides transformer_engine.pytorch.export_fp8_weights() for this. Not using the exporter leads to a manual weight conversion step that is error-prone.
TE with AMP (torch.autocast) simultaneously. Do not wrap TE layers in torch.autocast(dtype=torch.bfloat16) while also using te.fp8_autocast. The two context managers conflict: AMP intercepts GEMM dispatch before TE's FP8 path can, resulting in BF16 GEMMs instead of FP8. Use either AMP or TE's FP8 autocast, not both.
Transformer Engine vs vLLM / SGLang / TensorRT-LLM FP8 Paths
TE is a library. vLLM, SGLang, and TRT-LLM are serving runtimes. They are not competing choices for most teams - they operate at different levels.
| Use case | Recommended tool |
|---|---|
| Custom training loop or fine-tuning | Transformer Engine directly |
| Production inference serving | vLLM or SGLang (both use TE/cuBLAS FP8 ops internally) |
| Maximum inference throughput, offline | TensorRT-LLM (compiled engine, highest ceiling) |
| Research kernels, custom GEMM | Raw CUDA FP8 ops or Triton |
| Multi-modal or MoE training | NeMo (wraps TE) |
For production inference, go straight to vLLM vs TensorRT-LLM vs SGLang throughput benchmarks to pick the right serving runtime. Once you have picked one, the SGLang production deployment guide covers the full FP8 setup for SGLang.
Deploying Transformer Engine Workloads on Spheron H100 and H200
Both H100 SXM5 and H200 SXM5 are available on Spheron on-demand. Provisioning steps:
- Go to the Spheron dashboard and select a GPU instance.
- Choose H100 SXM5 (80 GB) for training up to 30B parameters or FP8 inference on 70B models.
- Choose H200 SXM5 (141 GB) for training 70B+ models or memory-bandwidth-bound inference.
- Select a base image with CUDA 12.1 or later. The NVIDIA NGC PyTorch container (nvcr.io/nvidia/pytorch:24.09-py3) comes with TE pre-installed.
Current pricing (on-demand and spot):
| Instance | On-demand | Spot |
|---|---|---|
| H100 SXM5 (80 GB) | $3.90/hr | $1.66/hr |
| H200 SXM5 (141 GB) | $4.62/hr | $1.92/hr |
For memory-intensive training runs where the H200's 141 GB matters, see rent H200 SXM5 on Spheron for availability and instance configuration options.
Pricing fluctuates based on GPU availability. The prices above are based on 22 May 2026 and may have changed. Check current GPU pricing for live rates.
Transformer Engine FP8 workloads need Hopper hardware. Spheron H100 and H200 SXM5 instances are available on-demand with per-minute billing, no contracts, and a 5+ provider backend for availability redundancy.
Quick Setup Guide
- Launch an H100 SXM5 or H200 SXM5 on-demand instance from the Spheron dashboard. The FP8 setup in this guide targets Hopper or Blackwell GPUs. Select a CUDA 12.1+ base image to ensure compatibility with Transformer Engine 1.x.
- Install via pip: `pip install transformer-engine[pytorch]`. For JAX support add `transformer-engine[jax]`. The package compiles CUDA extensions on first install; ensure CUDA dev headers are available in the container image. Use the NVIDIA NGC container (nvcr.io/nvidia/pytorch:24.xx-py3) to skip compilation entirely.
- Replace `torch.nn.Linear` with `transformer_engine.pytorch.Linear`, `torch.nn.LayerNorm` with `transformer_engine.pytorch.LayerNorm`, and use `transformer_engine.pytorch.TransformerLayer` for full attention+FFN blocks. The API is a drop-in replacement: same constructor arguments and forward signature.
- Create a `DelayedScaling` recipe (default) or `Float8CurrentScaling` recipe via `transformer_engine.common.recipe`. Wrap your forward pass in the `te.fp8_autocast(enabled=True, fp8_recipe=recipe)` context manager; the backward pass can run outside it. Master weights and optimizer states remain in higher precision (BF16/FP32) automatically.
- Compare BF16 and FP8 training runs on a held-out validation set. Expect 1.3-1.7x throughput improvement on H100 SXM5 for large transformer models. Check loss curves for divergence in the first 200 steps (delayed scaling artifact) and switch to `Float8CurrentScaling` if instability is observed.
Frequently Asked Questions
Not at full capability. Transformer Engine is built and tuned primarily for Hopper (H100, H200) and Blackwell (B200, B300, RTX 5090) hardware. The RTX 4090 (Ada Lovelace, compute capability 8.9) does have FP8 Tensor Cores, and recent TE releases can run FP8 on it given a new enough CUDA and cuBLAS, but it lacks Hopper-specific features and delivers far less FP8 throughput than an H100. On GPUs older than Ada, FP8 is unavailable and enabling te.fp8_autocast fails its hardware check, so verify FP8 support on your device before assuming you are getting FP8 GEMMs.
E4M3 (4 exponent bits, 3 mantissa bits) has higher precision and a smaller dynamic range. It is the preferred format for forward pass activations and weights where accuracy matters most. E5M2 (5 exponent bits, 2 mantissa bits) has lower precision but a wider dynamic range, making it better suited for gradients in the backward pass, where values can span many orders of magnitude. The recipe's fp8_format setting (Format.HYBRID or Format.E4M3) selects which format is used in each direction.
Transformer Engine is a lower-level library. vLLM and TensorRT-LLM both use Transformer Engine (or equivalent CUDA FP8 ops) under the hood for FP8 inference on Hopper hardware. If you are doing inference at scale, use vLLM or TensorRT-LLM directly: they add batching, KV cache management, and production serving infrastructure on top of raw Transformer Engine ops. Transformer Engine is most useful when you are writing custom training loops, fine-tuning code, or need fine-grained control over the FP8 recipe per layer.
AMD's ROCm stack supports FP8 on MI300X hardware via the hipBLASLt FP8 gemm ops and the aiter library (AMD's fused attention kernels). There is no direct AMD equivalent of the full Transformer Engine Python API, but frameworks like PyTorch with ROCm can use FP8 compute on MI300X for inference. For training, AMD's support is less mature than Transformer Engine on Hopper as of mid-2026. See the AMD MI300X vs NVIDIA H200 post for a hardware-level comparison.
Delayed scaling is the default FP8 scaling strategy in Transformer Engine. Instead of computing the scale factor for each tensor on the fly (which would require a costly max-reduction before each GEMM), TE maintains a history of tensor max values and uses a scale computed from the previous iteration. This amortizes the overhead of scale computation over many steps. The downside: if tensor values change suddenly (e.g., at the start of training or after a learning rate spike), the scale from the previous step may be stale, causing FP8 overflow or underflow for 1-2 steps. For most training runs this is harmless, but for fine-tuning runs with aggressive schedules it can cause instability in the first few hundred steps.
