FlexAttention compiles arbitrary mask patterns to FlashAttention-grade kernels using torch.compile, without any CUDA code. Introduced in PyTorch 2.5 and exposed via torch.nn.attention.flex_attention, it lets you define attention mask logic as plain Python functions that get fused into block-sparse Triton kernels at compile time. The result: sliding-window, ALiBi, document masking, and causal+prefix patterns that run at near-FlashAttention throughput without a single line of CUDA.
This guide covers the score_mod and mask_mod APIs, how torch.compile traces and fuses them, production benchmarks on H100/H200/B200, and a full vLLM integration recipe. For context on the FlashAttention-4 architecture that runs beneath FlexAttention on Blackwell hardware, see the FlashAttention-4 Blackwell inference guide.
If you want to follow along on actual hardware, H200 GPU rental on Spheron is available at $4.62/hr on-demand or $1.49/hr spot, with PyTorch 2.6 and CUDA 12.4 preinstalled.
What FlexAttention Is and Why It Exists
Before FlexAttention, if you needed attention with a non-standard mask, you had three options. You could implement the mask in dense form as a (seq_len, seq_len) matrix and add it to the attention logits before softmax. This works but blows up memory: at 128K tokens (131,072 positions), a float32 mask matrix takes 64 GiB, and even a one-byte boolean mask takes 16 GiB. You could fork FlashAttention and rewrite its CUDA kernel. That takes weeks and produces a kernel tied to one specific mask pattern. Or you could write a custom fused attention kernel in Triton, which is the most flexible option but still requires deep GPU programming knowledge.
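The quadratic growth of a dense mask is easy to check with a few lines of arithmetic. This is a minimal sketch; dense_mask_bytes is a hypothetical helper name, and 128K is taken to mean 131,072 tokens.

```python
def dense_mask_bytes(seq_len: int, bytes_per_element: int) -> int:
    """Bytes needed to store one dense (seq_len, seq_len) mask."""
    return seq_len * seq_len * bytes_per_element


SEQ_LEN = 128 * 1024  # 131,072 tokens
GIB = 2**30

# Element widths for three common mask dtypes.
for name, width in (("float32", 4), ("bfloat16", 2), ("bool", 1)):
    size = dense_mask_bytes(SEQ_LEN, width)
    print(f"{name:>8}: {size / GIB:.0f} GiB")
```

Doubling the sequence length quadruples every figure, which is why dense masks stop being practical well before 128K tokens.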
FlexAttention takes a different path. You write two Python functions:
- score_mod(score, b, h, q_idx, kv_idx): a scalar function that receives the raw attention logit for a given query/key pair and returns a modified logit. This runs before the softmax. Use it for relative position penalties, ALiBi slopes, or any per-logit transformation.
- mask_mod(b, h, q_idx, kv_idx): a boolean predicate that returns True if a given query/key pair should attend to each other and False otherwise. Use it for sliding window, document masking, or causal patterns.
torch.compile traces these functions, infers the block-sparsity pattern from mask_mod, and generates a fused Triton kernel. Blocks where mask_mod is entirely False are skipped, which is where the memory and compute savings come from.
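To make the block-skipping idea concrete, here is a conceptual sketch in plain Python. It is not how PyTorch builds a BlockMask internally: classify_blocks is a hypothetical helper that brute-forces every position, and the tile size of 4 is far smaller than anything a real kernel would use.

```python
def classify_blocks(mask_mod, q_len, kv_len, block_size):
    """Label each (q_block, kv_block) tile as 'empty', 'full', or 'partial'."""
    grid = []
    for q0 in range(0, q_len, block_size):
        row = []
        for k0 in range(0, kv_len, block_size):
            # Evaluate the predicate on every position inside this tile.
            hits = [
                bool(mask_mod(0, 0, q, k))
                for q in range(q0, min(q0 + block_size, q_len))
                for k in range(k0, min(k0 + block_size, kv_len))
            ]
            if not any(hits):
                row.append("empty")    # tile can be skipped entirely
            elif all(hits):
                row.append("full")     # tile needs no per-element masking
            else:
                row.append("partial")  # tile needs the mask applied inside
        grid.append(row)
    return grid


def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx


grid = classify_blocks(causal, q_len=8, kv_len=8, block_size=4)
for row in grid:
    print(row)
```

For a causal mask the upper-right tile is empty, so a block-sparse kernel never loads the keys and values for it. The savings grow with the share of empty tiles.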
FlexAttention vs Alternatives
| Backend | When to use | CUDA required | Custom masks | Memory (128K seqlen) |
|---|---|---|---|---|
| F.scaled_dot_product_attention (SDPA) | Standard causal or full attention | No | Limited (causal/alibi) | O(seq^2) for dense mask |
| FlashAttention-2/3 | Fixed causal or full attention, Hopper | Via pip install | No | O(seq) |
| FlashAttention-4 | Standard attention, Blackwell B200/B300 | Via pip install | No | O(seq) |
| FlexAttention | Custom mask patterns, Hopper/Blackwell | No (Python only) | Yes | O(seq) via block-sparse |
| Hand-written Triton kernel | Full control, novel architectures | Triton (Python) | Yes | Depends on implementation |
The key tradeoff: FlexAttention is slightly slower than FlashAttention-3 on standard causal attention because the Triton kernel generated by torch.compile is not as optimized as the hand-tuned FA3 CUDA kernel. On H200, expect 5-15% slower throughput for pure causal attention. But for non-standard masks, FlexAttention is the right tool because the alternative is a hand-written CUDA kernel that takes weeks and locks you into one pattern.
The score_mod and mask_mod APIs
Sliding Window Attention
Sliding window limits each query to attending only to keys within a fixed distance. This keeps KV cache size linear in window size rather than sequence length, which matters for long-context inference.
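A quick way to see the bound on attended keys is to evaluate the window rule with plain integers. This sketch uses the same predicate shape as the FlexAttention example in this section; visible_keys is a hypothetical helper, and no PyTorch is needed to run it.

```python
WINDOW_SIZE = 512


def sliding_window_mask(b, h, q_idx, kv_idx):
    # Causal, and at most WINDOW_SIZE positions back from the query.
    return (q_idx >= kv_idx) & (q_idx - kv_idx <= WINDOW_SIZE)


def visible_keys(q_idx: int, seq_len: int) -> int:
    """Count how many key positions a single query may attend to."""
    return sum(bool(sliding_window_mask(0, 0, q_idx, kv)) for kv in range(seq_len))


print(visible_keys(0, 8192))     # first token: only itself
print(visible_keys(100, 8192))   # still inside the warm-up region
print(visible_keys(8191, 8192))  # steady state
```

Because the comparison is inclusive (<=), the steady-state count is WINDOW_SIZE + 1: the query itself plus the 512 positions before it. The count does not grow with sequence length.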
import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

WINDOW_SIZE = 512  # constant, not a mutable tensor

def sliding_window_mask(b, h, q_idx, kv_idx):
    return (q_idx >= kv_idx) & (q_idx - kv_idx <= WINDOW_SIZE)

def identity_score_mod(score, b, h, q_idx, kv_idx):
    return score

B, H, SEQ_LEN, HEAD_DIM = 1, 32, 8192, 128

# Build the block-sparse mask at compile time
# Q_LEN and KV_LEN are separate; for self-attention they are equal
block_mask = create_block_mask(sliding_window_mask, B, H, SEQ_LEN, SEQ_LEN, device="cuda")

Q = torch.randn(B, H, SEQ_LEN, HEAD_DIM, device="cuda", dtype=torch.bfloat16)
K = torch.randn(B, H, SEQ_LEN, HEAD_DIM, device="cuda", dtype=torch.bfloat16)
V = torch.randn(B, H, SEQ_LEN, HEAD_DIM, device="cuda", dtype=torch.bfloat16)

compiled_flex = torch.compile(flex_attention)
output = compiled_flex(Q, K, V, score_mod=identity_score_mod, block_mask=block_mask)

ALiBi Positional Bias
ALiBi adds a slope-based penalty to each attention logit based on the distance between query and key positions. Different attention heads use different slopes.
def make_alibi_score_mod(num_heads: int, device="cuda"):
    # Precompute one slope per head as a tensor. During tracing h is a tensor,
    # and a tensor cannot index a Python tuple or list.
    slopes = torch.tensor(
        [2 ** (-8 * (h + 1) / num_heads) for h in range(num_heads)],
        dtype=torch.float32,
        device=device,
    )

    def alibi_score_mod(score, b, h, q_idx, kv_idx):
        bias = slopes[h] * (q_idx - kv_idx).abs()
        return score - bias

    return alibi_score_mod

alibi_mod = make_alibi_score_mod(num_heads=32)
compiled_flex = torch.compile(flex_attention)
output = compiled_flex(Q, K, V, score_mod=alibi_mod)

The closure over slopes works here because slopes is created once and never reassigned or modified in place. It has to be a tensor rather than a tuple of Python floats: h arrives as a tensor while the function is traced, so slopes[h] needs tensor indexing. Closures over tensors that change between calls are the case to avoid, because they can force recompilation.
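To see what the slope formula actually produces, here is a plain-Python sketch that needs no GPU. alibi_slopes and alibi_bias are illustrative helper names, and the formula is the one used in this section, which assumes a power-of-two head count.

```python
def alibi_slopes(num_heads: int) -> list[float]:
    """Geometric sequence of per-head slopes: 2^(-8/n), 2^(-16/n), ..."""
    return [2 ** (-8 * (h + 1) / num_heads) for h in range(num_heads)]


def alibi_bias(slope: float, q_idx: int, kv_idx: int) -> float:
    """Penalty subtracted from the logit for one query/key pair."""
    return slope * abs(q_idx - kv_idx)


slopes = alibi_slopes(8)
print(slopes[0], slopes[-1])         # steepest and shallowest slope
print(alibi_bias(slopes[0], 10, 4))  # distance 6 under the steepest slope
```

With 8 heads the slopes run from 0.5 down to 2^-8, so the first head discounts distant tokens sharply while the last head is almost position-agnostic.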
Document Masking for Packed Sequences
When you pack multiple documents into a single sequence for throughput, you need to prevent attention from crossing document boundaries. The naive approach is to process each document separately, which wastes GPU parallelism.
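The mask needs a per-token document ID, which usually has to be derived from document lengths. A minimal sketch in plain Python follows; doc_ids_from_lengths and same_document are hypothetical helpers, and in a real pipeline the resulting list would be converted to an integer tensor on the GPU.

```python
def doc_ids_from_lengths(lengths):
    """Expand per-document lengths into one document ID per token."""
    ids = []
    for doc_id, n in enumerate(lengths):
        ids.extend([doc_id] * n)
    return ids


def same_document(doc_ids, q_idx, kv_idx):
    """The document-mask predicate, evaluated on plain integers."""
    return doc_ids[q_idx] == doc_ids[kv_idx]


doc_ids = doc_ids_from_lengths([3, 2, 4])
print(doc_ids)
print(same_document(doc_ids, 1, 2))  # both in document 0
print(same_document(doc_ids, 2, 3))  # crosses the 0/1 boundary
```

The predicate alone allows attention in both directions inside a document. For decoder-only training it is typically combined with a causal condition.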
import functools

def make_document_mask(doc_ids_q, doc_ids_kv):
    # doc_ids_q and doc_ids_kv are 1-D int tensors of shape [SEQ_LEN].
    # This is a bare closure over tensors. torch.compile traces it for static
    # document layouts, but may fall back to dense SDPA for dynamic tensor inputs.
    def document_mask(b, h, q_idx, kv_idx):
        return doc_ids_q[q_idx] == doc_ids_kv[kv_idx]
    return document_mask

# Example: two documents packed in one sequence
doc_ids = torch.tensor([0]*512 + [1]*512, device="cuda", dtype=torch.int32)
doc_mask_fn = make_document_mask(doc_ids, doc_ids)
block_mask = create_block_mask(doc_mask_fn, B, H, 1024, 1024, device="cuda")
output = compiled_flex(Q[:, :, :1024, :], K[:, :, :1024, :], V[:, :, :1024, :], block_mask=block_mask)

Causal + Prefix (Prefix-LM)
Prefix-LM models process a fully-visible prefix (e.g., a prompt) followed by a causally-generated suffix. The mask needs to allow full bidirectional attention within the prefix, causal attention within the suffix, and attention from every suffix token to the entire prefix.
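Printing the mask for a tiny sequence makes the three regions visible. This is a plain-Python sketch with a prefix of 3 tokens and a total length of 6, both chosen only for illustration.

```python
PREFIX_LEN = 3
SEQ_LEN = 6


def prefix_lm_mask(b, h, q_idx, kv_idx):
    # Prefix keys are visible to everyone; other keys follow the causal rule.
    return (kv_idx < PREFIX_LEN) | (kv_idx <= q_idx)


# One row per query position, one column per key position.
rows = [
    "".join("x" if prefix_lm_mask(0, 0, q, kv) else "." for kv in range(SEQ_LEN))
    for q in range(SEQ_LEN)
]
print("\n".join(rows))
```

The first three rows are identical because prefix tokens see the whole prefix and nothing else. The remaining rows form the usual causal staircase.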
PREFIX_LEN = 256

def prefix_lm_mask(b, h, q_idx, kv_idx):
    prefix_visible = kv_idx < PREFIX_LEN   # prefix is always visible
    causal_in_suffix = kv_idx <= q_idx     # causal for suffix tokens
    return prefix_visible | causal_in_suffix

block_mask = create_block_mask(prefix_lm_mask, B, H, SEQ_LEN, SEQ_LEN, device="cuda")
output = compiled_flex(Q, K, V, block_mask=block_mask)

torch.compile Interaction: Kernels, Fallbacks, and Recompile Traps
How Compilation Works
When you call torch.compile(flex_attention) and then invoke the compiled function, Dynamo traces through the score_mod and mask_mod functions to infer the block-sparsity structure. The output is a fused Triton kernel that:
- Reads the block_mask to determine which (Q, K) tile pairs are non-empty.
- Skips zero blocks entirely.
- Fuses the score_mod transformation into the softmax computation within each non-zero tile.
The kernel is cached per (batch, heads, seq_len, head_dim, device_type) tuple. The first call triggers compilation (30-90 seconds for large models). Subsequent calls with the same shape hit the cache.
When FlexAttention Generates Optimal Kernels
- score_mod and mask_mod are pure functions with no Python closures over mutable tensors.
- Input shapes are static (same (B, H, SEQ_LEN, HEAD_DIM) across calls).
- head_dim is in {64, 128, 256} (as of PyTorch 2.5/2.6; check the upstream PyTorch FlexAttention docs for updates in later versions).
- block_mask is passed explicitly (not None).
When It Falls Back to Dense SDPA
- score_mod contains a Python closure over a mutable tensor (not a constant or functools.partial-captured tensor). Dynamo hits a graph break and falls back to element-wise operations.
- head_dim is not in {64, 128, 256}. FlexAttention logs a warning and calls F.scaled_dot_product_attention instead.
- block_mask is None. Without the sparsity pattern, FlexAttention runs in dense mode, which is functionally correct but loses all memory savings.
Recompile Thrashing and Dynamic Shapes
The shape guard issue is the most common production problem. Each new (B, H, SEQ_LEN, HEAD_DIM) combination triggers a full recompile. For variable-length inference with many unique sequence lengths:
Option 1: Mark seq_len as dynamic before compiling.
Q = torch.randn(B, H, SEQ_LEN, HEAD_DIM, device="cuda", dtype=torch.bfloat16)
torch._dynamo.mark_dynamic(Q, 2)  # dim 2 is seq_len
torch._dynamo.mark_dynamic(K, 2)
torch._dynamo.mark_dynamic(V, 2)
compiled_flex = torch.compile(flex_attention)

Dynamo generates a symbolic kernel that handles a range of seq_len values without recompilation.
Option 2: Bucket-pad to fixed lengths.
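BUCKETS = [512, 1024, 2048, 4096, 8192, 16384, 32768]

def pad_to_bucket(seq_len: int) -> int:
    for b in BUCKETS:
        if seq_len <= b:
            return b
    return seq_len  # or raise if you want strict control

padded_len = pad_to_bucket(actual_seq_len)

Bucket-padding is simpler to reason about and avoids symbolic shape complexity. At most len(BUCKETS) kernels are compiled and cached. Padding tokens waste some compute, and how much depends on how closely the buckets track your length distribution: a sequence just under a bucket boundary pads by a few percent, while one just over a boundary in this power-of-two list nearly doubles in length. Add intermediate buckets if your traffic clusters just above a boundary.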
For the broader torch.compile production patterns including CUDA graph capture and Inductor cache persistence, see the torch.compile and CUDA Graphs guide.
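Before settling on a bucket list, it helps to measure how much padding it adds for your own length distribution. The sketch below reuses the bucket list from Option 2; padding_overhead is a hypothetical helper, and it counts padded tokens rather than attention FLOPs, so treat it as a rough indicator.

```python
BUCKETS = [512, 1024, 2048, 4096, 8192, 16384, 32768]


def pad_to_bucket(seq_len: int) -> int:
    """Round a sequence length up to the next bucket."""
    for b in BUCKETS:
        if seq_len <= b:
            return b
    return seq_len


def padding_overhead(seq_lens) -> float:
    """Extra tokens added by padding, as a fraction of the real tokens."""
    real = sum(seq_lens)
    padded = sum(pad_to_bucket(n) for n in seq_lens)
    return (padded - real) / real


near_edge = padding_overhead([500, 1000, 2000])   # just under bucket boundaries
past_edge = padding_overhead([513, 1025, 2049])   # just over bucket boundaries
print(f"{near_edge:.1%}  {past_edge:.1%}")
```

The two sample distributions bracket the best and worst cases for power-of-two buckets. Running the helper over a sample of production lengths shows where extra buckets would pay off.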
Production Benchmarks: H100, H200, B200
Benchmark Methodology
Benchmarks measure prefill throughput (TFLOPs/s) and time-to-first-token (TTFT in ms) for:
- FlexAttention with sliding-window mask (window=512)
- Vanilla F.scaled_dot_product_attention with is_causal=True
- FlashAttention-2 (for H100/H200)
Config: num_heads=32, head_dim=128, batch=1, dtype=bfloat16. Each call is bracketed by torch.cuda.synchronize(), and each reported figure is the median of 100 runs.
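The timing procedure can be sketched as a small harness. This is an assumption about how such a measurement might be structured, not the script behind the numbers in this guide. median_ms is a hypothetical helper; for GPU work you would pass sync=torch.cuda.synchronize, and the default no-op lets the sketch run on any machine.

```python
import statistics
import time


def median_ms(fn, *, warmup=3, runs=100, sync=lambda: None):
    """Median wall-clock time of fn() in milliseconds.

    sync is called before and after each timed call so that asynchronous
    work (for example queued GPU kernels) is included in the measurement.
    """
    for _ in range(warmup):  # warm-up absorbs one-time compilation cost
        fn()
    sync()
    samples = []
    for _ in range(runs):
        sync()
        start = time.perf_counter()
        fn()
        sync()
        samples.append((time.perf_counter() - start) * 1e3)
    return statistics.median(samples)


# CPU stand-in workload so the sketch is runnable without a GPU.
elapsed = median_ms(lambda: sum(range(10_000)), runs=20)
print(f"median: {elapsed:.3f} ms")
```

The median is less sensitive than the mean to occasional slow runs caused by clock changes or background activity.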
Prefill TTFT at Variable Sequence Length
| seq_len | H100 SXM5 SDPA (ms) | H100 SXM5 FlexAttn SW (ms) | H200 SXM5 SDPA (ms) | H200 SXM5 FlexAttn SW (ms) | B200 SXM6 FlexAttn SW (ms) |
|---|---|---|---|---|---|
| 2K | 4.2 | 5.1 | 2.9 | 3.4 | 2.1 |
| 8K | 31.4 | 12.8 | 20.1 | 8.3 | 5.2 |
| 32K | 482 | 48.1 | 308 | 30.7 | 18.9 |
| 128K | OOM | 189 | OOM | 121 | 74.3 |
Directional results. Exact numbers vary by model architecture, driver version, and HBM occupancy. FlexAttention's advantage over dense SDPA grows sharply at long sequences because the sliding-window block_mask eliminates most KV pairs, while SDPA must compute and store the full (seq_len, seq_len) attention matrix.
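The widening gap follows from how few query/key pairs a fixed window keeps as the sequence grows. A minimal sketch of that arithmetic, assuming the inclusive window rule 0 <= q - kv <= window; both function names are hypothetical.

```python
def sliding_window_pairs(seq_len: int, window: int) -> int:
    """Count (q, kv) pairs with 0 <= q - kv <= window."""
    return sum(min(q, window) + 1 for q in range(seq_len))


def kept_fraction(seq_len: int, window: int) -> float:
    """Share of the full seq_len x seq_len score matrix that is computed."""
    return sliding_window_pairs(seq_len, window) / (seq_len * seq_len)


for n in (2048, 8192, 32768, 131072):
    print(f"{n:>6}: {kept_fraction(n, 512):.2%}")
```

The kept fraction falls roughly in proportion to window / seq_len, so the work for the windowed kernel grows linearly with sequence length while the dense matrix grows quadratically.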
Cost-Per-Hour on Spheron
| GPU | On-Demand | Spot |
|---|---|---|
| H100 SXM5 | $2.64/hr | $1.63/hr |
| H200 SXM5 | $4.62/hr | $1.49/hr |
| B200 SXM6 | $7.21/hr | $3.77/hr |
Pricing fluctuates based on GPU availability. The prices above are based on 18 May 2026 and may have changed. Check current GPU pricing for live rates.
For long-context workloads (32K+ tokens), the H200 at $4.62/hr on-demand carries higher upfront cost than the H100 at $2.64/hr, but H200's larger HBM3e (141 GB vs 80 GB) allows larger effective batch sizes before OOM, which can offset the price difference for batch jobs. Spot H200 at $1.49/hr is the best option for batch inference jobs that can tolerate preemption.
Real Use Cases
Long-Context Document RAG
When you pack multiple documents into a single batch sequence for retrieval-augmented generation, document masking is essential. Without it, tokens from document A can attend to tokens from document B, which leaks context and degrades retrieval quality.
FlexAttention's mask_mod makes this clean: define doc_ids_q[q_idx] == doc_ids_kv[kv_idx] and pass a precomputed block_mask. The block-sparse kernel skips all inter-document KV pairs without any extra padding or sequence truncation.
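At 32K tokens packed across 8 equal-length documents, this pattern reduces attention compute by roughly 7/8 compared to dense attention, matching the theoretical sparsity of a block-diagonal mask (8 blocks of 4K x 4K out of a 32K x 32K matrix). Unequal document lengths save less, because the largest documents dominate the sum of squared lengths. For KV cache optimization strategies that complement this approach, see the KV cache optimization guide.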
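The sparsity of a document mask depends on the length distribution, which a short calculation makes explicit. block_diagonal_fraction is a hypothetical helper, and the skewed length list is an invented example that also sums to 32K tokens.

```python
def block_diagonal_fraction(doc_lengths) -> float:
    """Share of the score matrix kept when attention stays within documents."""
    total = sum(doc_lengths)
    return sum(n * n for n in doc_lengths) / (total * total)


equal = block_diagonal_fraction([4096] * 8)
skewed = block_diagonal_fraction([16384, 8192, 4096, 2048, 1024, 512, 256, 256])
print(f"equal lengths:  {equal:.1%} of pairs kept")
print(f"skewed lengths: {skewed:.1%} of pairs kept")
```

With equal lengths only one eighth of the pairs survive. In the skewed case the single 16K document alone accounts for a quarter of the full matrix, so the saving is noticeably smaller.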
Multi-Turn Agent Traces
Agent systems accumulate long conversation histories. Each turn appends to the sequence. For a 10-turn conversation, a causal+prefix mask makes every prior turn fully visible to the current generation step, while preventing the current step from attending to future tokens.
This is more expressive than standard causal masking (which only sees tokens to the left) and cheaper than full attention (which processes all tokens symmetrically). For sequence parallelism patterns that scale this further, see the Ring Attention and Tree Attention guide.
Vision Encoders with Spatial Locality
For ViT-style architectures where image patches have spatial locality (nearby patches are more relevant than distant ones), a 2D sliding window mask reduces attention cost from O(N^2) to O(N * w^2), where N is the number of patches and w is the window side length in patches (w = 2W + 1 for a window radius W). FlexAttention can implement 2D window masks directly:
GRID_SIZE = 14  # 14x14 patch grid for a 224px image with 16px patches
W = 3           # window radius in grid units

def vit_window_mask(b, h, q_idx, kv_idx):
    q_row, q_col = q_idx // GRID_SIZE, q_idx % GRID_SIZE
    k_row, k_col = kv_idx // GRID_SIZE, kv_idx % GRID_SIZE
    return ((q_row - k_row).abs() <= W) & ((q_col - k_col).abs() <= W)

Gotchas
Dynamic Shapes and Recompile Thrashing
Covered in the torch.compile section above. The fix is either torch._dynamo.mark_dynamic on the seq_len dimension or bucket-padding to a fixed set of lengths.
head_dim Constraints
FlexAttention's fastest Triton templates target head_dim in {64, 128, 256}. Other values (e.g., head_dim=96 in some GPT-2 variants, or head_dim=48) may compile but will run with reduced performance rather than hitting the optimized block-sparse kernels.
The safe pattern is to check at model init time:
SUPPORTED_HEAD_DIMS = {64, 128, 256}

if head_dim not in SUPPORTED_HEAD_DIMS:
    import warnings
    warnings.warn(
        f"head_dim={head_dim} not supported by FlexAttention (supported: {SUPPORTED_HEAD_DIMS}). "
        "Falling back to F.scaled_dot_product_attention."
    )
    use_flex_attention = False

Check the upstream FlexAttention docs for the supported set in your PyTorch version; the supported set has loosened in PyTorch 2.7+.
score_mod Must Be a Pure Function
score_mod and mask_mod must have no Python closures over mutable state. If you capture a torch.Tensor variable in a closure, Dynamo may try to include it as a graph input, which either causes a recompile on every new tensor value or produces incorrect results if Dynamo treats the tensor as a constant.
Safe: closures over Python scalars (int, float), tuples of scalars, or functools.partial-captured integer arguments.
Unsafe: closures over torch.Tensor objects that change between calls.
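One way to parameterize a mask without capturing mutable state is functools.partial over an integer. The sketch below checks the predicate with plain integers only. It relies on this guide's statement that partial-captured integers are safe, so verify the behavior under torch.compile on your PyTorch version.

```python
import functools


def sliding_window(b, h, q_idx, kv_idx, *, window):
    # window is bound once via functools.partial and never changes afterwards.
    return (q_idx >= kv_idx) & (q_idx - kv_idx <= window)


# Each partial has the (b, h, q_idx, kv_idx) call signature of a mask_mod.
sw_512 = functools.partial(sliding_window, window=512)
sw_128 = functools.partial(sliding_window, window=128)

print(sw_512(0, 0, 600, 100))  # distance 500 fits in a 512 window
print(sw_128(0, 0, 600, 100))  # distance 500 does not fit in a 128 window
```

Each distinct window value yields a distinct function object, so it is worth creating the partials once at startup and reusing them rather than rebuilding them per request.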
Cache create_block_mask Results
create_block_mask precomputes the sparsity pattern from mask_mod for a given (B, H, Q_LEN, KV_LEN). This computation is non-trivial: it calls mask_mod on a grid of positions, builds the block-dense representation, and allocates the result tensor. For a 32K-token sequence, this takes 10-50 ms.
Cache the BlockMask object and reuse it across calls with the same sequence length:
from functools import lru_cache

@lru_cache(maxsize=32)
def get_block_mask(seq_len: int, device: str) -> "BlockMask":
    def sw_mask(b, h, q_idx, kv_idx):
        return (q_idx >= kv_idx) & (q_idx - kv_idx <= 512)
    return create_block_mask(sw_mask, 1, 32, seq_len, seq_len, device=device)

Note that create_block_mask takes Q_LEN and KV_LEN as separate arguments. For self-attention they are equal. For cross-attention (e.g., encoder-decoder where the query comes from the decoder and the key/value come from the encoder), they differ and both must be passed explicitly.
When to Use FlexAttention vs FlashAttention-4 vs SGLang RadixAttention
| Scenario | Recommended backend |
|---|---|
| Standard causal LLM, no custom mask, Blackwell | FA4 via vLLM/SGLang auto-detect |
| Standard causal LLM, no custom mask, Hopper | FA3 via vLLM/SGLang auto-detect |
| Custom mask pattern, Hopper (H100, H200) | FlexAttention |
| Custom mask pattern, Blackwell (B200, B300) | FlexAttention (falls through to FA4 kernel) |
| Long-context multi-turn with KV prefix reuse | SGLang RadixAttention (FlexAttention available via --attention-backend flex_attention but page_size must be 1, no FP8 KV cache, no speculative decoding) |
| Research prototype, arbitrary mask pattern | FlexAttention |
| Maximum throughput, fixed mask, B200 | FA4 (hand-tuned kernel outperforms FlexAttention Triton kernel) |
For SGLang deployment and RadixAttention configuration, see the SGLang production deployment guide.
SGLang exposes a FlexAttention backend via --attention-backend flex_attention, but with significant limitations as of mid-2026: page_size must be 1 (no paged KV cache with page_size > 1), no FP8 KV cache, no sliding window, no speculative decoding, and no multimodal support. For production workloads that hit any of those limits, route custom-mask requests through a standalone PyTorch or vLLM backend with FlexAttention enabled, and keep RadixAttention for prefix-reuse-heavy workloads where the mask is standard causal.
Deployment Recipe: PyTorch 2.6 + FlexAttention + vLLM/SGLang
Step 1: Provision and Verify
Provision an H200 or B200 on Spheron. Verify CUDA 12.4+:
nvidia-smi --query-gpu=name,compute_cap --format=csv
# Expected: H200 SXM5, 9.0 or B200 SXM6, 10.0
nvcc --version
# Expected: release 12.4 or later

Step 2: Install PyTorch 2.6
pip install torch==2.6.0 --index-url https://download.pytorch.org/whl/cu124

python -c "
from torch.nn.attention.flex_attention import flex_attention, create_block_mask
import torch
print('FlexAttention available, PyTorch', torch.__version__)
print('CUDA:', torch.version.cuda)
"

Step 3: Install vLLM with FlexAttention Backend Support
pip install "vllm>=0.6.0"

vLLM 0.6.0+ includes a FlexAttention backend. Verify:

python -c "from vllm.v1.attention.backends.flex_attention import FlexAttentionBackend; print('ok')"  # vLLM >= 0.6 (v1 engine)

Step 4: Write a Custom vLLM Attention Backend
For a custom mask pattern (e.g., sliding window with ALiBi):
# my_flex_backend.py
import torch
import functools
from torch.nn.attention.flex_attention import flex_attention, create_block_mask
from vllm.v1.attention.backends.registry import register_backend, AttentionBackendEnum
from vllm.v1.attention.backend import AttentionBackend

WINDOW_SIZE = 512

def sw_mask(b, h, q_idx, kv_idx):
    return (q_idx >= kv_idx) & (q_idx - kv_idx <= WINDOW_SIZE)

def make_alibi_score_mod(slopes):
    # slopes is a 1-D tensor with one entry per head. h is a tensor during
    # tracing, so the lookup must be tensor indexing, not tuple indexing.
    def alibi_score_mod(score, b, h, q_idx, kv_idx):
        return score - slopes[h] * (q_idx - kv_idx).abs()
    return alibi_score_mod

@functools.lru_cache(maxsize=32)
def get_cached_block_mask(q_len: int, kv_len: int, num_heads: int, device: str):
    return create_block_mask(sw_mask, 1, num_heads, q_len, kv_len, device=device)

@register_backend(AttentionBackendEnum.CUSTOM)
class SlidingWindowALiBiBackend(AttentionBackend):
    def __init__(self, num_heads: int, head_dim: int, scale: float, **kwargs):
        super().__init__(**kwargs)
        self.num_heads = num_heads
        # Assumes the model runs on the default CUDA device.
        slopes = torch.tensor(
            [2 ** (-8 * (h + 1) / num_heads) for h in range(num_heads)],
            dtype=torch.float32, device="cuda",
        )
        self.score_mod = make_alibi_score_mod(slopes)
        self._compiled_flex = torch.compile(flex_attention)

    def forward(self, query, key, value, kv_cache, attn_metadata, **kwargs):
        # This backend handles one sequence at a time (batch_size=1). Batched
        # decode with N sequences would require iterating over block_tables per
        # sequence and calling flex_attention separately for each; doing so with
        # this implementation silently reads only sequence 0's KV cache for all N
        # tokens. Raise early so the failure is explicit rather than silent.
        if len(attn_metadata.seq_lens) != 1:
            raise ValueError(
                f"SlidingWindowALiBiBackend only supports batch_size=1 "
                f"(got {len(attn_metadata.seq_lens)} sequences). "
                "For batched decode, iterate over block_tables per sequence."
            )

        # Write the current-step key/value tokens into the paged KV cache at their
        # allocated slots before reading back the full accumulated sequence.
        key_cache, value_cache = kv_cache  # each: (num_blocks, num_heads, block_size, head_dim)
        torch.ops.vllm.reshape_and_cache_flash(
            key, value, key_cache, value_cache,
            attn_metadata.slot_mapping, kv_cache_dtype="auto",
        )

        # Gather the full accumulated key/value sequences from the paged cache.
        # The block_mask is built with kv_len = seq_lens[0] (the actual filled
        # length), so k/v match exactly and we avoid reading garbage KV entries
        # from the last partially-filled block.
        num_tokens = query.shape[0]
        # seq_lens[0] is the exact token count for this sequence, unlike max_seq_len
        # which pads to the longest sequence in the batch. For a single sequence
        # these are equal, but using seq_lens[0] makes the intent explicit.
        kv_len = attn_metadata.seq_lens[0]
        # NOTE: q_idx inside sw_mask and the score_mod is relative to the query
        # tensor. During decode (num_tokens < kv_len) both would need the offset
        # kv_len - num_tokens added to q_idx, so this sketch is only
        # position-correct for prefill, where num_tokens == kv_len.
        block_mask = get_cached_block_mask(num_tokens, kv_len, self.num_heads, str(query.device))

        # batch_size=1: query holds either the full prompt (prefill) or the single
        # new decode token. unsqueeze(0) adds the batch dimension.
        q = query.unsqueeze(0).transpose(1, 2)  # (1, num_heads, num_tokens, head_dim)

        # Reconstruct contiguous (1, num_heads, kv_len, head_dim) tensors from the
        # paged cache. block_tables[0] is safe: we asserted a single sequence above.
        # Reshape gathers physical blocks into a contiguous buffer; trim to kv_len to
        # discard padding in the last partially-filled block.
        seq_blocks = attn_metadata.block_tables[0]  # (max_blocks_per_seq,)
        k = key_cache[seq_blocks].permute(0, 2, 1, 3).reshape(1, -1, self.num_heads, key_cache.shape[-1])
        v = value_cache[seq_blocks].permute(0, 2, 1, 3).reshape(1, -1, self.num_heads, value_cache.shape[-1])
        k = k[:, :kv_len].permute(0, 2, 1, 3).contiguous()  # (1, num_heads, kv_len, head_dim)
        v = v[:, :kv_len].permute(0, 2, 1, 3).contiguous()

        out = self._compiled_flex(
            q, k, v,
            score_mod=self.score_mod,
            block_mask=block_mask,
        )
        # Reshape output from (1, num_heads, num_tokens, head_dim) back to
        # (num_tokens, num_heads, head_dim) as vLLM expects.
        return out.squeeze(0).transpose(0, 1)

The @register_backend(AttentionBackendEnum.CUSTOM) decorator registers the class with vLLM's v1 attention registry. Import the module before launching, then pass --attention-backend CUSTOM to select it:
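python -c "import my_flex_backend" && \
python -m vllm.entrypoints.openai.api_server \
    --model meta-llama/Llama-3-70b-instruct \
    --attention-backend CUSTOM \
    --tensor-parallel-size 2

The python -c line only confirms that the module imports cleanly; it runs in a separate process from the server. Registration has to execute inside the server process and each tensor-parallel worker, so make sure my_flex_backend is imported there, for example from a small launcher script or a plugin hook, depending on your vLLM version. For the full vLLM deployment configuration including CUDA graph tuning and batching, see the vLLM production deployment guide.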
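Going back to the forward() method in Step 4, the paged-cache gather is the least obvious part. The idea can be shown with plain Python lists standing in for tensors. gather_paged, the block numbers, and the token labels are all invented for illustration and are not part of vLLM.

```python
def gather_paged(cache, block_table, kv_len):
    """Rebuild one sequence's tokens from non-contiguous physical blocks.

    cache maps a physical block number to its list of slots, block_table
    lists the sequence's blocks in logical order, and kv_len is the number
    of slots that actually hold data.
    """
    tokens = []
    for physical_block in block_table:
        tokens.extend(cache[physical_block])
    return tokens[:kv_len]  # drop unfilled slots in the last block


# Three physical blocks of 4 slots each; None marks slots not yet written.
cache = {
    7: ["t0", "t1", "t2", "t3"],
    2: ["t4", "t5", "t6", "t7"],
    5: ["t8", "t9", None, None],
}
block_table = [7, 2, 5]  # logical order differs from physical numbering

sequence = gather_paged(cache, block_table, kv_len=10)
print(sequence)
```

The trim to kv_len matters because the last block is usually only partly filled. Without it, attention would read whatever stale values sit in the unused slots.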
Step 5: Profile with PyTorch Profiler
import torch
from torch.profiler import profile, ProfilerActivity, tensorboard_trace_handler

with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    on_trace_ready=tensorboard_trace_handler("./flex_trace"),
    record_shapes=True,
    with_stack=True,
) as prof:
    for _ in range(10):
        out = compiled_flex(Q, K, V, score_mod=score_mod, block_mask=block_mask)
    torch.cuda.synchronize()

print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))

Look for flex_attention in the CUDA kernel names. If you see scaled_dot_product_attention_flash_attention instead, you hit a fallback to dense SDPA. Causes: unsupported head_dim, None block_mask, or a graph break in score_mod.
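The same check can be automated over a list of kernel names. This is a hedged sketch: attention_path is a hypothetical helper, the marker strings are the ones this guide tells you to look for, and the sample names below are invented, so actual kernel names on your system may differ.

```python
FALLBACK_MARKER = "scaled_dot_product_attention_flash_attention"
FLEX_MARKERS = ("flex", "block_sparse")


def attention_path(kernel_names):
    """Classify a profiler run by the attention kernels it contains."""
    names = [name.lower() for name in kernel_names]
    # Any fallback kernel is reported first, even if flex kernels also ran.
    if any(FALLBACK_MARKER in name for name in names):
        return "dense-sdpa"
    if any(marker in name for name in names for marker in FLEX_MARKERS):
        return "flex"
    return "unknown"


# Invented sample names, for illustration only.
print(attention_path(["example_flex_attention_kernel", "example_elementwise"]))
print(attention_path(["scaled_dot_product_attention_flash_attention"]))
print(attention_path(["example_gemm"]))
```

With a real trace, the input could be built from the profiler output, for example [evt.key for evt in prof.key_averages()]. A check like this fits well in a CI smoke test that fails when a change silently triggers the fallback.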
For deeper custom kernel integration and Triton kernel debugging, see the OpenAI Triton kernel development guide.
FlexAttention turns any mask pattern into a compiled, FlashAttention-grade kernel. Spheron's bare-metal H200 and B200 instances ship with PyTorch 2.6 prebuilt, so you skip the container fights and go straight to benchmarking.
Quick Setup Guide
Go to app.spheron.ai, select GPU Cloud, and filter for H200 SXM5 or B200 SXM6. Choose an instance with CUDA 12.4 or later. H200 instances are available at $4.62/hr on-demand or $1.49/hr spot. B200 instances start at $7.21/hr on-demand. SSH in and verify the GPU with nvidia-smi.
Run: pip install torch==2.6.0 --index-url https://download.pytorch.org/whl/cu124. Then verify: python -c "from torch.nn.attention.flex_attention import flex_attention, create_block_mask; print('FlexAttention available')". If you see the print output, the install is correct.
Define a pure Python score function: def score_mod(score, b, h, q_idx, kv_idx): return score. Then define the mask: def sw_mask(b, h, q_idx, kv_idx): return (q_idx >= kv_idx) & (q_idx - kv_idx <= window_size). Use functools.partial or a constant capture (not a mutable tensor capture) for window_size. Call bm = create_block_mask(sw_mask, B, H, SEQ_LEN, SEQ_LEN) to precompute the block-sparse mask, then pass it to flex_attention(Q, K, V, score_mod=score_mod, block_mask=bm).
Wrap your FlexAttention call with torch.compile: compiled_flex = torch.compile(flex_attention). Run a warmup pass (at least 3 forward passes) to trigger kernel compilation, then time with torch.cuda.synchronize() brackets. Compare against F.scaled_dot_product_attention(Q, K, V, is_causal=True) for the same shapes to measure the FlexAttention kernel overhead or speedup.
Subclass vLLM's AttentionBackend and implement forward() to call flex_attention with your cached block_mask. Register the class with the @register_backend(AttentionBackendEnum.CUSTOM) decorator, import the module so the registration runs, and launch vLLM with --attention-backend CUSTOM. See the Deployment Recipe section for the full class skeleton.
Wrap your inference call with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA]). Look for flex_attention entries in the CUDA trace. Check that the kernel name contains 'flex' or 'block_sparse' rather than 'scaled_dot_product_attention_flash_attention', which would indicate a fallback to dense SDPA.
Frequently Asked Questions
What is FlexAttention and how does it differ from FlashAttention?
FlexAttention is a PyTorch 2.5+ API in torch.nn.attention.flex_attention that lets you define custom attention mask patterns as Python functions (score_mod and mask_mod), which torch.compile compiles into block-sparse FlashAttention-equivalent CUDA kernels via Triton. FlashAttention is a fixed implementation of standard causal or full attention optimized for hardware memory access patterns. FlexAttention gives you FlashAttention-grade kernel efficiency with arbitrary mask logic you write in Python, without needing to implement a CUDA kernel yourself.
Which GPUs does FlexAttention support?
FlexAttention requires CUDA-capable Hopper or newer GPUs (compute capability 9.0+) for optimal performance. H100, H200, and B200 GPUs all support it. On Hopper (H100, H200) it generates Triton-based block-sparse kernels. On Blackwell (B200, B300) it falls through to FlashAttention-4's SM100 tile kernels when available. It will run on older GPUs (A100 and earlier) but falls back to dense SDPA without the block-sparse optimization, so the memory savings and speed gains are Hopper-only.
Does FlexAttention require torch.compile?
Yes, torch.compile is required to get optimized kernels from FlexAttention. Without torch.compile, score_mod and mask_mod functions run in eager mode as element-wise Python operations, which is far slower than FlashAttention. With torch.compile, Dynamo traces the score_mod and mask_mod functions, generates a fused block-sparse Triton kernel, and produces near-FlashAttention throughput. The score_mod and mask_mod functions must be pure (no Python closures over mutable tensors) to avoid graph breaks.
When should I use FlexAttention, FlashAttention-4, or SGLang RadixAttention?
Use FlashAttention-4 (via vLLM or SGLang auto-detection) for standard causal LLM inference on Blackwell hardware where you do not need custom mask patterns. Use FlexAttention when you have a non-standard mask pattern such as sliding window, ALiBi, document masking, or causal+prefix LM, and want to implement it without writing CUDA. Use SGLang RadixAttention for multi-turn chat and agent workloads where KV cache prefix reuse between requests is the bottleneck, not the attention mask shape.
How do I avoid recompile thrashing with variable sequence lengths?
FlexAttention installs shape guards for each (batch, heads, seq_len, head_dim) tuple. If seq_len changes frequently, torch.compile recompiles on each new shape, which eliminates the speedup and adds latency. Two fixes: (1) use torch._dynamo.mark_dynamic on the seq_len dimension before compiling, which causes Dynamo to generate a symbolic kernel that works across seq_len values, or (2) bucket-pad inputs to fixed sequence lengths such as 512, 1024, 2048, 4096, 8192, and always pass a fixed shape. Option 2 wastes some compute on padding tokens but is simpler to implement in production.
