PyTorch FlexAttention: Custom Attention Patterns in Production (2026 Guide)

Written by Mitrasish, Co-founder, May 18, 2026

FlexAttention compiles custom attention patterns into FlashAttention-grade kernels using torch.compile, without any CUDA code. Introduced in PyTorch 2.5 and exposed via torch.nn.attention.flex_attention, it lets you define mask and score logic as plain Python functions that are fused into block-sparse Triton kernels at compile time. The result: sliding-window, ALiBi, document-masking, and causal+prefix patterns that run at near-FlashAttention throughput.

This guide covers the score_mod and mask_mod APIs, how torch.compile traces and fuses them, production benchmarks on H100/H200/B200, and a full vLLM integration recipe. For context on the FlashAttention-4 architecture that runs beneath FlexAttention on Blackwell hardware, see the FlashAttention-4 Blackwell inference guide.

If you want to follow along on actual hardware, H200 GPU rental on Spheron is available at $4.62/hr on-demand or $1.49/hr spot, with PyTorch 2.6 and CUDA 12.4 preinstalled.

What FlexAttention Is and Why It Exists

Before FlexAttention, if you needed attention with a non-standard mask, you had three options. You could implement the mask in dense form as a (seq_len, seq_len) matrix and add it to the attention logits before softmax. This works but blows up memory: a 128K-token sequence needs a 64 GiB float32 mask matrix (16 GiB even as a one-byte boolean mask). You could fork FlashAttention and rewrite its CUDA kernel, which takes weeks and produces a kernel tied to one specific mask pattern. Or you could use Triton to write a custom fused attention kernel: the most flexible option, but one that still requires deep GPU programming knowledge.
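The arithmetic behind those figures, as a small Python helper (dense_mask_bytes is an illustrative name, not a library function). It assumes a single mask shared across the batch and all heads; a per-head or per-example mask multiplies the cost accordingly:

```python
def dense_mask_bytes(seq_len: int, bytes_per_element: int, batch: int = 1, heads: int = 1) -> int:
    """Bytes needed to materialize a dense (seq_len x seq_len) attention mask.

    batch and heads default to 1, i.e. one mask broadcast over the batch and
    all heads; a per-head or per-example mask multiplies the cost.
    """
    return batch * heads * seq_len * seq_len * bytes_per_element

GIB = 2 ** 30
seq_len = 128 * 1024  # 131,072 tokens

fp32_gib = dense_mask_bytes(seq_len, bytes_per_element=4) / GIB
bool_gib = dense_mask_bytes(seq_len, bytes_per_element=1) / GIB
print(f"float32 mask: {fp32_gib:.0f} GiB, bool mask: {bool_gib:.0f} GiB")
# float32 mask: 64 GiB, bool mask: 16 GiB
```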

FlexAttention takes a different path. You write two Python functions:

  • score_mod(score, b, h, q_idx, kv_idx): a scalar function that receives the raw attention logit for a given query/key pair and returns a modified logit. This runs before the softmax. Use it for relative position penalties, ALiBi slopes, or any per-logit transformation.
  • mask_mod(b, h, q_idx, kv_idx): a boolean predicate that returns True if the query at position q_idx may attend to the key at position kv_idx, and False otherwise. Use it for sliding window, document masking, or causal patterns.

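create_block_mask evaluates mask_mod ahead of time and records which (query tile, key tile) pairs are fully masked, partially masked, or fully visible. torch.compile then traces score_mod and mask_mod into a fused Triton kernel that consumes this BlockMask: tiles where mask_mod is False everywhere are skipped, which is where the compute savings come from, and the full score matrix is never materialized, which keeps memory linear in sequence length.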
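To make the tile skipping concrete, here is a simplified pure-Python model of what a block mask records. It is an illustration, not PyTorch's implementation: classify_tiles is a hypothetical helper, and the 128-token block matches create_block_mask's default block size. PyTorch's BlockMask also tracks fully visible tiles separately from partial ones, so only partial tiles need per-element mask_mod evaluation inside the kernel:

```python
from typing import Callable, Dict

MaskMod = Callable[[int, int, int, int], bool]

def classify_tiles(mask_mod: MaskMod, seq_len: int, block: int) -> Dict[str, int]:
    """Count (query-tile, key-tile) pairs that are empty, partial, or full.

    Empty tiles can be skipped entirely, full tiles need no per-element
    masking, and only partial tiles must evaluate mask_mod element by element.
    """
    counts = {"empty": 0, "partial": 0, "full": 0}
    n_blocks = -(-seq_len // block)  # ceiling division
    for qb in range(n_blocks):
        for kb in range(n_blocks):
            visible = [
                mask_mod(0, 0, q, k)
                for q in range(qb * block, min((qb + 1) * block, seq_len))
                for k in range(kb * block, min((kb + 1) * block, seq_len))
            ]
            if not any(visible):
                counts["empty"] += 1
            elif all(visible):
                counts["full"] += 1
            else:
                counts["partial"] += 1
    return counts

def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

print(classify_tiles(causal, seq_len=1024, block=128))
# {'empty': 28, 'partial': 8, 'full': 28}
```

For a causal mask on an 8 x 8 tile grid, the 28 tiles above the diagonal are skipped outright and only the 8 diagonal tiles need element-wise masking.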

FlexAttention vs Alternatives

| Backend | When to use | CUDA required | Custom masks | Memory (128K seq_len) |
|---|---|---|---|---|
| F.scaled_dot_product_attention (SDPA) | Standard causal or full attention | No | Built-in causal only; arbitrary masks via a dense attn_mask tensor | O(seq^2) for dense mask |
| FlashAttention-2/3 | Causal or full attention, Ampere (FA2) / Hopper (FA3) | Via pip install | Limited (causal; FA2 also sliding window and ALiBi) | O(seq) |
| FlashAttention-4 | Standard attention, Blackwell B200/B300 | Via pip install | No | O(seq) |
| FlexAttention | Custom mask patterns, Hopper/Blackwell | No (Python only) | Yes | O(seq) via block-sparse |
| Hand-written Triton kernel | Full control, novel architectures | Triton (Python) | Yes | Depends on implementation |

The key tradeoff: FlexAttention is slightly slower than FlashAttention-3 on standard causal attention because the Triton kernel generated by torch.compile is not as optimized as the hand-tuned FA3 CUDA kernel. On H200, expect 5-15% slower throughput for pure causal attention. But for non-standard masks, FlexAttention is the right tool because the alternative is a hand-written CUDA kernel that takes weeks and locks you into one pattern.

The score_mod and mask_mod APIs

Sliding Window Attention

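Sliding window limits each query to attending only to keys within a fixed distance. Per-query attention cost is then bounded by the window size rather than the sequence length, and an inference engine can evict KV entries older than the window, which keeps KV cache size proportional to the window rather than the sequence. Both matter for long-context inference.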

python
import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

WINDOW_SIZE = 512  # constant, not a mutable tensor

def sliding_window_mask(b, h, q_idx, kv_idx):
    return (q_idx >= kv_idx) & (q_idx - kv_idx <= WINDOW_SIZE)

def identity_score_mod(score, b, h, q_idx, kv_idx):
    return score

B, H, SEQ_LEN, HEAD_DIM = 1, 32, 8192, 128

# Precompute the block-sparse mask once, outside the compiled call
# Q_LEN and KV_LEN are separate; for self-attention they are equal
block_mask = create_block_mask(sliding_window_mask, B, H, SEQ_LEN, SEQ_LEN, device="cuda")

Q = torch.randn(B, H, SEQ_LEN, HEAD_DIM, device="cuda", dtype=torch.bfloat16)
K = torch.randn(B, H, SEQ_LEN, HEAD_DIM, device="cuda", dtype=torch.bfloat16)
V = torch.randn(B, H, SEQ_LEN, HEAD_DIM, device="cuda", dtype=torch.bfloat16)

compiled_flex = torch.compile(flex_attention)
output = compiled_flex(Q, K, V, score_mod=identity_score_mod, block_mask=block_mask)
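Because the predicate uses <=, each query sees itself plus the WINDOW_SIZE previous tokens: 513 keys in total, not 512. A quick pure-Python check (keys_visible is an illustrative helper) makes the off-by-one visible; switch to < if you want exactly WINDOW_SIZE keys per query:

```python
WINDOW_SIZE = 512

def sliding_window_visible(q_idx: int, kv_idx: int) -> bool:
    # Same predicate as sliding_window_mask above, evaluated on plain ints.
    return (q_idx >= kv_idx) and (q_idx - kv_idx <= WINDOW_SIZE)

def keys_visible(q_idx: int, seq_len: int) -> int:
    # Number of key positions the query at q_idx may attend to.
    return sum(sliding_window_visible(q_idx, k) for k in range(seq_len))

for q in (0, 100, 512, 4000):
    print(q, keys_visible(q, seq_len=8192))
# 0 1
# 100 101
# 512 513
# 4000 513
```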

ALiBi Positional Bias

ALiBi adds a slope-based penalty to each attention logit based on the distance between query and key positions. Different attention heads use different slopes.

python
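import torch

def make_alibi_score_mod(num_heads: int, device: str = "cuda"):
    # Per-head slopes 2^(-8(h+1)/num_heads), stored in a tensor: inside score_mod,
    # h arrives as a tensor, so it can index a tensor but not a Python tuple.
    slopes = torch.tensor(
        [2 ** (-8 * (h + 1) / num_heads) for h in range(num_heads)],
        device=device,
        dtype=torch.float32,
    )

    def alibi_score_mod(score, b, h, q_idx, kv_idx):
        bias = slopes[h] * (q_idx - kv_idx).abs()
        return score - bias

    return alibi_score_mod

alibi_mod = make_alibi_score_mod(num_heads=32)
compiled_flex = torch.compile(flex_attention)
output = compiled_flex(Q, K, V, score_mod=alibi_mod)

The slopes tensor is captured by the closure and passed to the compiled kernel as an extra input, the same pattern the PyTorch FlexAttention examples use for ALiBi. A tuple of Python floats does not work here: score_mod receives h as a tensor, and indexing a Python tuple with a traced tensor fails during tracing. Because this call passes no block_mask, every tile is computed; ALiBi changes scores, not sparsity.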
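For reference, the slope schedule evaluated in plain Python for a hypothetical 8-head model. The slopes form a geometric sequence starting at 2^(-8/num_heads), so heads with small slopes keep long-range context while heads with large slopes focus locally. alibi_slopes and alibi_bias are illustrative helpers; this schedule matches the ALiBi paper's recipe when num_heads is a power of two, and the paper uses a different construction for other head counts:

```python
from typing import List

def alibi_slopes(num_heads: int) -> List[float]:
    # slope_h = 2 ** (-8 * (h + 1) / num_heads), one value per head.
    return [2 ** (-8 * (h + 1) / num_heads) for h in range(num_heads)]

def alibi_bias(slope: float, q_idx: int, kv_idx: int) -> float:
    # Same penalty as alibi_score_mod above: subtract slope * |q - kv|.
    return -slope * abs(q_idx - kv_idx)

slopes = alibi_slopes(8)
print(slopes)
# [0.5, 0.25, 0.125, 0.0625, 0.03125, 0.015625, 0.0078125, 0.00390625]
print(alibi_bias(slopes[0], q_idx=10, kv_idx=2))
# -4.0
```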

Document Masking for Packed Sequences

When you pack multiple documents into a single sequence for throughput, you need to prevent attention from crossing document boundaries. The naive approach is to process each document separately, which wastes GPU parallelism.

python
def make_document_mask(doc_ids_q, doc_ids_kv):
    # doc_ids_q and doc_ids_kv are 1-D int tensors of shape [SEQ_LEN] on the GPU.
    # Capturing them in the closure is supported: the tensors become extra kernel
    # inputs. Rebuild the block_mask whenever the packing layout changes.
    def document_mask(b, h, q_idx, kv_idx):
        return doc_ids_q[q_idx] == doc_ids_kv[kv_idx]
    return document_mask

# Example: two documents packed in one sequence
doc_ids = torch.tensor([0]*512 + [1]*512, device="cuda", dtype=torch.int32)
doc_mask_fn = make_document_mask(doc_ids, doc_ids)

block_mask = create_block_mask(doc_mask_fn, B, H, 1024, 1024, device="cuda")
output = compiled_flex(Q[:, :, :1024, :], K[:, :, :1024, :], V[:, :, :1024, :], block_mask=block_mask)
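Real packed batches rarely split evenly. A small pure-Python helper (doc_ids_from_lengths and document_offsets are illustrative names, not FlexAttention APIs) can expand per-document lengths into the per-token ID layout that make_document_mask expects; wrap the result in torch.tensor(..., dtype=torch.int32, device="cuda") before use:

```python
from itertools import accumulate
from typing import List

def doc_ids_from_lengths(lengths: List[int]) -> List[int]:
    """Expand per-document lengths into a per-token document-ID list.

    doc_ids_from_lengths([3, 2]) -> [0, 0, 0, 1, 1]
    """
    return [doc for doc, n in enumerate(lengths) for _ in range(n)]

def document_offsets(lengths: List[int]) -> List[int]:
    # Start offset of each document inside the packed sequence.
    return [0] + list(accumulate(lengths))[:-1]

lengths = [300, 512, 212]  # three documents packed into 1,024 tokens
ids = doc_ids_from_lengths(lengths)
print(len(ids), document_offsets(lengths))
# 1024 [0, 300, 812]
```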

Causal + Prefix (Prefix-LM)

Prefix-LM models process a fully visible prefix (e.g., a prompt) followed by a causally generated suffix. The mask needs full bidirectional attention within the prefix, causal attention among suffix tokens, and every suffix token able to see the whole prefix. Prefix tokens never see suffix tokens.

python
PREFIX_LEN = 256

def prefix_lm_mask(b, h, q_idx, kv_idx):
    prefix_visible = kv_idx < PREFIX_LEN          # prefix is always visible
    causal_in_suffix = kv_idx <= q_idx             # causal for suffix tokens
    return prefix_visible | causal_in_suffix

block_mask = create_block_mask(prefix_lm_mask, B, H, SEQ_LEN, SEQ_LEN, device="cuda")
output = compiled_flex(Q, K, V, block_mask=block_mask)
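To sanity-check the prefix-LM predicate, here it is on plain Python ints with a hypothetical 3-token prefix in a 6-token sequence. Rows are queries, columns are keys, and 1 marks a visible pair:

```python
PREFIX_LEN = 3  # small illustrative values so the whole mask fits on screen
SEQ_LEN = 6

def prefix_lm_visible(q_idx: int, kv_idx: int) -> bool:
    # Same predicate as prefix_lm_mask above, evaluated on plain ints.
    return kv_idx < PREFIX_LEN or kv_idx <= q_idx

for q in range(SEQ_LEN):
    print("".join("1" if prefix_lm_visible(q, k) else "." for k in range(SEQ_LEN)))
# 111...
# 111...
# 111...
# 1111..
# 11111.
# 111111
```

The top-left 3 x 3 block is fully visible (bidirectional prefix), the prefix rows never reach into the suffix columns, and the suffix rows are causal.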

torch.compile Interaction: Kernels, Fallbacks, and Recompile Traps

How Compilation Works

When you call torch.compile(flex_attention) and then invoke the compiled function, Dynamo traces through score_mod and mask_mod, and Inductor lowers them, together with the block_mask you pass in (precomputed by create_block_mask), into a fused Triton kernel that:

  1. Reads the block_mask to determine which (Q, K) tile pairs are non-empty.
  2. Skips zero blocks entirely.
  3. Fuses the score_mod transformation into the softmax computation within each non-zero tile.

The kernel is cached per (batch, heads, seq_len, head_dim, device_type) tuple. The first call triggers compilation (30-90 seconds for large models). Subsequent calls with the same shape hit the cache.

When FlexAttention Generates Optimal Kernels

  • score_mod and mask_mod are pure functions with no Python closures over mutable tensors.
  • Input shapes are static (same (B, H, SEQ_LEN, HEAD_DIM) across calls).
  • head_dim is in {64, 128, 256} (as of PyTorch 2.5/2.6; check the upstream PyTorch FlexAttention docs for updates in later versions).
  • block_mask is passed explicitly (not None).

When It Falls Back to Dense SDPA

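  • score_mod or mask_mod uses Python-level operations that cannot be traced, such as indexing a Python list or tuple with the h or q_idx tensor, or an if statement on a tensor value. Dynamo hits a graph break or a tracing error instead of producing the fused kernel. (Capturing a tensor and indexing it, as in the document-masking example, is fine.)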
  • head_dim is not in {64, 128, 256}. FlexAttention logs a warning and calls F.scaled_dot_product_attention instead.
  • block_mask is None. Without the sparsity pattern, FlexAttention runs in dense mode, which is functionally correct but loses all memory savings.

Recompile Thrashing and Dynamic Shapes

The shape guard issue is the most common production problem. Each new (B, H, SEQ_LEN, HEAD_DIM) combination triggers a full recompile. For variable-length inference with many unique sequence lengths:

Option 1: Mark seq_len as dynamic before compiling.

python
Q = torch.randn(B, H, SEQ_LEN, HEAD_DIM, device="cuda", dtype=torch.bfloat16)
torch._dynamo.mark_dynamic(Q, 2)  # dim 2 is seq_len
torch._dynamo.mark_dynamic(K, 2)
torch._dynamo.mark_dynamic(V, 2)
compiled_flex = torch.compile(flex_attention)

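Dynamo generates a symbolic kernel that handles a range of seq_len values without recompilation; torch.compile(flex_attention, dynamic=True) is a blunter alternative that requests dynamic-shape tracing up front. Either way, a BlockMask still encodes one specific sequence length, so you need one per distinct seq_len (cache them, as described under Gotchas).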

Option 2: Bucket-pad to fixed lengths.

python
BUCKETS = [512, 1024, 2048, 4096, 8192, 16384, 32768]

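def pad_to_bucket(seq_len: int) -> int:
    # Return the smallest bucket that fits seq_len.
    for bucket in BUCKETS:
        if seq_len <= bucket:
            return bucket
    return seq_len  # or raise if you want strict control

actual_seq_len = 3000  # example: a 3,000-token request
padded_len = pad_to_bucket(actual_seq_len)  # -> 4096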

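Bucket-padding is simpler to reason about and avoids symbolic shape complexity. At most len(BUCKETS) kernels are compiled and cached. Two costs come with it. First, padded key positions must be masked out (for example, by AND-ing the mask_mod with kv_idx < seq_lens[b], where seq_lens is a captured per-sequence length tensor) so real tokens never attend to padding. Second, padded positions waste compute, and how much depends on where lengths fall: with power-of-two buckets, a 4,097-token request pads to 8,192 and nearly doubles the work, so finer-grained buckets trade more compiled kernels for less waste.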
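A pure-Python sketch of both pieces, under the assumption of a strict bucket list that raises past the largest bucket. make_padding_aware_mask and padding_waste are hypothetical helpers; in real PyTorch code seq_lens would be an int tensor of shape [B] on the GPU so the traced b can index it:

```python
from typing import Callable, List

BUCKETS = [512, 1024, 2048, 4096, 8192, 16384, 32768]

def pad_to_bucket(seq_len: int) -> int:
    # Smallest bucket that fits; strict variant that raises past the last bucket.
    for bucket in BUCKETS:
        if seq_len <= bucket:
            return bucket
    raise ValueError(f"seq_len={seq_len} exceeds the largest bucket ({BUCKETS[-1]})")

def make_padding_aware_mask(mask_mod: Callable, seq_lens) -> Callable:
    # seq_lens[b] is the unpadded length of batch element b (a plain list here).
    def padded_mask(b, h, q_idx, kv_idx):
        # Padding keys (kv_idx >= true length) are never visible.
        return mask_mod(b, h, q_idx, kv_idx) & (kv_idx < seq_lens[b])
    return padded_mask

def padding_waste(lengths: List[int]) -> float:
    # Fraction of padded token positions that hold padding rather than real tokens.
    padded = sum(pad_to_bucket(n) for n in lengths)
    return 1 - sum(lengths) / padded

def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

lengths = [700, 3000, 5000]
print([pad_to_bucket(n) for n in lengths])            # [1024, 4096, 8192]
mask = make_padding_aware_mask(causal, seq_lens=[700])
print(mask(0, 0, 1023, 699), mask(0, 0, 1023, 700))   # True False
print(round(padding_waste(lengths), 3))               # 0.346
```

Here about a third of the token positions are padding. Because attention cost grows faster than linearly in sequence length, the wasted attention compute is larger still; outputs at padded query positions should simply be discarded.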

For the broader torch.compile production patterns including CUDA graph capture and Inductor cache persistence, see the torch.compile and CUDA Graphs guide.

Production Benchmarks: H100, H200, B200

Benchmark Methodology

Benchmarks measure prefill throughput (TFLOPs/s) and time-to-first-token (TTFT in ms) for:

  • FlexAttention with sliding-window mask (window=512)
  • Vanilla F.scaled_dot_product_attention with is_causal=True
  • FlashAttention-2 (for H100/H200)

Config: num_heads=32, head_dim=128, batch=1, dtype=bfloat16. torch.cuda.synchronize() brackets around each call, median of 100 runs.

Prefill TTFT at Variable Sequence Length

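| seq_len | H100 SXM5 SDPA (ms) | H100 SXM5 FlexAttn SW (ms) | H200 SXM5 SDPA (ms) | H200 SXM5 FlexAttn SW (ms) | B200 SXM6 FlexAttn SW (ms) |
|---|---|---|---|---|---|
| 2K | 4.2 | 5.1 | 2.9 | 3.4 | 2.1 |
| 8K | 31.4 | 12.8 | 20.1 | 8.3 | 5.2 |
| 32K | 482 | 48.1 | 308 | 30.7 | 18.9 |
| 128K | OOM | 189 | OOM | 121 | 74.3 |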

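Directional results. Exact numbers vary by model architecture, driver version, and HBM occupancy. Note that the columns compute different patterns: FlexAttention applies a 512-token sliding window, while SDPA runs full causal attention. FlexAttention's advantage therefore grows sharply with sequence length: the block_mask skips every tile outside the window, so its work grows roughly linearly in seq_len, while causal attention's work grows quadratically. The SDPA OOM entries at 128K suggest that run materialized a dense attention matrix (as SDPA's math backend does); SDPA's fused flash backend keeps memory linear in seq_len.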
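A tile-counting sketch makes the scaling argument concrete. nonempty_tiles is a hypothetical helper that counts the 128 x 128 tiles a block-sparse kernel must visit for a causal mask, with or without the 512-token window used in this guide; it measures work only, not kernel efficiency:

```python
from typing import Optional

def nonempty_tiles(seq_len: int, block: int = 128, window: Optional[int] = None) -> int:
    """Count (query-tile, key-tile) pairs containing at least one visible pair.

    Models a causal mask, optionally restricted to q_idx - kv_idx <= window.
    """
    n = -(-seq_len // block)  # ceiling division
    count = 0
    for qb in range(n):
        q_lo, q_hi = qb * block, min((qb + 1) * block, seq_len) - 1
        for kb in range(n):
            k_lo, k_hi = kb * block, min((kb + 1) * block, seq_len) - 1
            if k_lo > q_hi:  # every key in the tile is in the future
                continue
            if window is not None and q_lo - k_hi > window:  # every key is too old
                continue
            count += 1
    return count

for seq_len in (2048, 8192, 32768):
    causal = nonempty_tiles(seq_len)
    windowed = nonempty_tiles(seq_len, window=512)
    print(f"{seq_len:>6}: causal={causal:>6} sliding={windowed:>5} ratio={windowed / causal:.3f}")
```

With a 512-token window and 128-token tiles, every tile row past the first few visits 5 tiles, so sliding-window work grows linearly while causal work grows with the square of seq_len: 2,080 vs 310 tiles at 8K, and 32,896 vs 1,270 at 32K.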

Cost-Per-Hour on Spheron

| GPU | On-Demand | Spot |
|---|---|---|
| H100 SXM5 | $2.64/hr | $1.63/hr |
| H200 SXM5 | $4.62/hr | $1.49/hr |
| B200 SXM6 | $7.21/hr | $3.77/hr |

Pricing fluctuates based on GPU availability. The prices above are based on 18 May 2026 and may have changed. Check current GPU pricing for live rates.

For long-context workloads (32K+ tokens), the H200 at $4.62/hr on-demand carries higher upfront cost than the H100 at $2.64/hr, but H200's larger HBM3e (141 GB vs 80 GB) allows larger effective batch sizes before OOM, which can offset the price difference for batch jobs. Spot H200 at $1.49/hr is the best option for batch inference jobs that can tolerate preemption.

Real Use Cases

Long-Context Document RAG

When you pack multiple documents into a single batch sequence for retrieval-augmented generation, document masking is essential. Without it, tokens from document A can attend to tokens from document B, which leaks context and degrades retrieval quality.

FlexAttention's mask_mod makes this clean: define doc_ids_q[q_idx] == doc_ids_kv[kv_idx] and pass a precomputed block_mask. The block-sparse kernel skips all inter-document KV pairs without any extra padding or sequence truncation.

At 32K tokens packed across 8 equal-length documents, this pattern reduces attention compute by roughly 7/8 compared to dense attention, matching the theoretical sparsity. Unequal document lengths save less, because the largest documents dominate the visible area. For KV cache optimization strategies that complement this approach, see the KV cache optimization guide.
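The 7/8 figure follows from counting visible pairs: each document of length n keeps an n x n block, so the kept fraction is the sum of squared lengths over the squared total. A short sketch (visible_fraction is a hypothetical helper; it ignores causal masking and tile granularity):

```python
from typing import List

def visible_fraction(doc_lengths: List[int]) -> float:
    """Fraction of (query, key) pairs a bidirectional document mask keeps.

    Each document of length n contributes an n x n block of visible pairs,
    so the kept fraction is sum(n_i^2) / (sum(n_i))^2.
    """
    total = sum(doc_lengths)
    return sum(n * n for n in doc_lengths) / (total * total)

print(visible_fraction([4096] * 8))           # 0.125, so 7/8 of the work is skipped
print(visible_fraction([16384, 8192, 8192]))  # 0.375, a skewed 32K packing keeps more
```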

Multi-Turn Agent Traces

Agent systems accumulate long conversation histories. Each turn appends to the sequence. For a 10-turn conversation, a causal+prefix mask makes every prior turn fully visible to the current generation step, while preventing the current step from attending to future tokens.

This is more expressive than standard causal masking (which only sees tokens to the left) and cheaper than full attention (which processes all tokens symmetrically). For sequence parallelism patterns that scale this further, see the Ring Attention and Tree Attention guide.

Vision Encoders with Spatial Locality

For ViT-style architectures where image patches have spatial locality (nearby patches are more relevant than distant ones), a 2D sliding window mask reduces attention cost from O(N^2) to O(N * (2W+1)^2) for N patches and window radius W, because each patch attends to at most a (2W+1) x (2W+1) neighborhood. FlexAttention can implement 2D window masks directly:

python
GRID_SIZE = 14  # 14x14 patch grid for a 224px image with 16px patches
W = 3           # window radius in grid units

def vit_window_mask(b, h, q_idx, kv_idx):
    q_row, q_col = q_idx // GRID_SIZE, q_idx % GRID_SIZE
    k_row, k_col = kv_idx // GRID_SIZE, kv_idx % GRID_SIZE
    return ((q_row - k_row).abs() <= W) & ((q_col - k_col).abs() <= W)
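A pure-Python count of visible keys for the vit_window_mask predicate, reusing GRID_SIZE=14 and W=3 from above (neighbors is an illustrative helper). Interior patches see a full 7 x 7 neighborhood, edge and corner patches see less, and the whole mask keeps under a fifth of the 196 x 196 pairs:

```python
GRID_SIZE = 14
W = 3

def vit_window_visible(q_idx: int, kv_idx: int) -> bool:
    # Same predicate as vit_window_mask above, on plain ints.
    q_row, q_col = divmod(q_idx, GRID_SIZE)
    k_row, k_col = divmod(kv_idx, GRID_SIZE)
    return abs(q_row - k_row) <= W and abs(q_col - k_col) <= W

def neighbors(q_idx: int) -> int:
    # Number of patches the patch at q_idx may attend to.
    return sum(vit_window_visible(q_idx, k) for k in range(GRID_SIZE * GRID_SIZE))

N = GRID_SIZE * GRID_SIZE
center = 7 * GRID_SIZE + 7
print(neighbors(center), neighbors(0))   # 49 16
total = sum(neighbors(q) for q in range(N))
print(total, N * N, round(total / (N * N), 3))   # 7396 38416 0.193
```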

Gotchas

Dynamic Shapes and Recompile Thrashing

Covered in the torch.compile section above. The fix is either torch._dynamo.mark_dynamic on the seq_len dimension or bucket-padding to a fixed set of lengths.

head_dim Constraints

FlexAttention's fastest Triton templates target head_dim in {64, 128, 256}. Other values (e.g., head_dim=96 as in Phi-3-mini, or head_dim=48) may compile but will run with reduced performance rather than hitting the optimized block-sparse kernels.

The safe pattern is to check at model init time:

python
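import warnings

SUPPORTED_HEAD_DIMS = {64, 128, 256}

def can_use_flex_attention(head_dim: int) -> bool:
    # Warn and report False so the caller can route to SDPA instead.
    if head_dim not in SUPPORTED_HEAD_DIMS:
        warnings.warn(
            f"head_dim={head_dim} not supported by FlexAttention (supported: {sorted(SUPPORTED_HEAD_DIMS)}). "
            "Falling back to F.scaled_dot_product_attention."
        )
        return False
    return True

use_flex_attention = can_use_flex_attention(head_dim=128)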

Check the upstream FlexAttention docs for the supported set in your PyTorch version; the supported set has loosened in PyTorch 2.7+.

score_mod Must Be a Pure Function

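score_mod and mask_mod are traced, not executed as ordinary Python, so everything they do must be expressible as tensor operations on their arguments. In particular, b, h, q_idx, and kv_idx arrive as tensors: you can use them to index a captured tensor (doc_ids[q_idx], slopes[h]), but not a Python list or tuple, and you cannot branch on their values with if.

Safe: closures over Python scalars (int, float) used in arithmetic, functools.partial-captured scalar arguments, and captured tensors indexed by the arguments (these are passed to the kernel as extra inputs).

Unsafe: indexing Python containers with the index arguments, Python if/else on tensor values, or building a fresh closure with different captured Python constants on every call, which can trigger repeated recompiles.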

Cache create_block_mask Results

create_block_mask precomputes the sparsity pattern from mask_mod for a given (B, H, Q_LEN, KV_LEN). This computation is non-trivial: it calls mask_mod on a grid of positions, builds the block-dense representation, and allocates the result tensor. For a 32K-token sequence, this takes 10-50 ms.

Cache the BlockMask object and reuse it across calls with the same sequence length:

python
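from functools import lru_cache
from torch.nn.attention.flex_attention import BlockMask, create_block_mask

def sw_mask(b, h, q_idx, kv_idx):
    return (q_idx >= kv_idx) & (q_idx - kv_idx <= 512)

@lru_cache(maxsize=32)
def get_block_mask(seq_len: int, device: str) -> BlockMask:
    # Key on plain Python ints and strings: tensors hash by identity, so passing
    # a tensor length would miss the cache on every call.
    return create_block_mask(sw_mask, 1, 32, seq_len, seq_len, device=device)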

Note that create_block_mask takes Q_LEN and KV_LEN as separate arguments. For self-attention they are equal. For cross-attention (e.g., encoder-decoder where the query comes from the decoder and the key/value come from the encoder), they differ and both must be passed explicitly.
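The caching pattern can be exercised without a GPU by swapping create_block_mask for a stand-in that records when it runs (get_mask and build_calls are hypothetical names). It shows that self-attention and cross-attention shapes get separate cache entries because Q_LEN and KV_LEN are both part of the key, and that repeated lengths are served from the cache:

```python
from functools import lru_cache

build_calls = []  # records every (q_len, kv_len) that triggered a real build

@lru_cache(maxsize=32)
def get_mask(q_len: int, kv_len: int, device: str) -> tuple:
    # Stand-in for create_block_mask(...): in real code this is the expensive
    # call that evaluates mask_mod over the grid. Here it only records the call.
    build_calls.append((q_len, kv_len))
    return ("block_mask", q_len, kv_len, device)

get_mask(4096, 4096, "cuda")  # self-attention: built
get_mask(4096, 4096, "cuda")  # same lengths: served from the cache
get_mask(256, 4096, "cuda")   # cross-attention shape, Q_LEN != KV_LEN: built
print(build_calls)                 # [(4096, 4096), (256, 4096)]
print(get_mask.cache_info().hits)  # 1
```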

When to Use FlexAttention vs FlashAttention-4 vs SGLang RadixAttention

| Scenario | Recommended backend |
|---|---|
| Standard causal LLM, no custom mask, Blackwell | FA4 via vLLM/SGLang auto-detect |
| Standard causal LLM, no custom mask, Hopper | FA3 via vLLM/SGLang auto-detect |
| Custom mask pattern, Hopper (H100, H200) | FlexAttention |
| Custom mask pattern, Blackwell (B200, B300) | FlexAttention (falls through to FA4 kernel) |
| Long-context multi-turn with KV prefix reuse | SGLang RadixAttention (FlexAttention available via --attention-backend flex_attention, but page_size must be 1, no FP8 KV cache, no speculative decoding) |
| Research prototype, arbitrary mask pattern | FlexAttention |
| Maximum throughput, fixed mask, B200 | FA4 (hand-tuned kernel outperforms FlexAttention Triton kernel) |

For SGLang deployment and RadixAttention configuration, see the SGLang production deployment guide.

SGLang exposes a FlexAttention backend via --attention-backend flex_attention, but with significant limitations as of mid-2026: page_size must be 1 (no paged KV cache with page_size > 1), no FP8 KV cache, no sliding window, no speculative decoding, and no multimodal support. For production workloads that hit any of those limits, route custom-mask requests through a standalone PyTorch or vLLM backend with FlexAttention enabled, and keep RadixAttention for prefix-reuse-heavy workloads where the mask is standard causal.

Deployment Recipe: PyTorch 2.6 + FlexAttention + vLLM/SGLang

Step 1: Provision and Verify

Provision an H200 or B200 on Spheron. Verify CUDA 12.4+:

bash
nvidia-smi --query-gpu=name,compute_cap --format=csv
# Expected: H200 SXM5, 9.0 or B200 SXM6, 10.0
nvcc --version
# Expected: release 12.4 or later

Step 2: Install PyTorch 2.6

bash
pip install torch==2.6.0 --index-url https://download.pytorch.org/whl/cu124
python -c "
from torch.nn.attention.flex_attention import flex_attention, create_block_mask
import torch
print('FlexAttention available, PyTorch', torch.__version__)
print('CUDA:', torch.version.cuda)
"

Step 3: Install vLLM with FlexAttention Backend Support

bash
pip install "vllm>=0.6.0"  # quote the spec, or the shell treats > as a redirect

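Recent vLLM releases include a FlexAttention backend under the V1 engine (vllm.v1), which postdates the 0.6 series. Since the unpinned upper bound normally resolves to the latest release, verify that your installed build actually ships it:

bash
python -c "from vllm.v1.attention.backends.flex_attention import FlexAttentionBackend; print('ok')"  # ImportError means this build lacks the V1 FlexAttention backend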

Step 4: Write a Custom vLLM Attention Backend

For a custom mask pattern (e.g., sliding window with ALiBi):

python
# my_flex_backend.py
import functools

import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask
from vllm.v1.attention.backends.registry import register_backend, AttentionBackendEnum
from vllm.v1.attention.backend import AttentionBackend

WINDOW_SIZE = 512

def make_mods(slopes: torch.Tensor, q_offset: torch.Tensor):
    # q_idx indexes the current query chunk, while kv_idx indexes the whole cached
    # sequence. q_offset (a 0-d int32 tensor holding kv_len - q_len) turns q_idx
    # into an absolute position, so the mask is also correct for decode (q_len=1),
    # not only for a full prefill where q_len == kv_len.
    def sw_mask(b, h, q_idx, kv_idx):
        q_pos = q_idx + q_offset
        return (q_pos >= kv_idx) & (q_pos - kv_idx <= WINDOW_SIZE)

    def alibi_score_mod(score, b, h, q_idx, kv_idx):
        # slopes is a tensor so it can be indexed by the traced head index h.
        return score - slopes[h] * (q_idx + q_offset - kv_idx).abs()

    return sw_mask, alibi_score_mod

@functools.lru_cache(maxsize=32)
def get_cached_block_mask(mask_mod, q_len: int, kv_len: int, num_heads: int, device: str):
    # Each (q_len, kv_len) key implies one offset (kv_len - q_len), so a cached
    # mask stays valid as long as q_offset is set before it is built or reused.
    return create_block_mask(mask_mod, 1, num_heads, q_len, kv_len, device=device)

@register_backend(AttentionBackendEnum.CUSTOM)
class SlidingWindowALiBiBackend(AttentionBackend):
    def __init__(self, num_heads: int, head_dim: int, scale: float, **kwargs):
        super().__init__(**kwargs)
        self.num_heads = num_heads
        self.scale = scale
        slopes = torch.tensor(
            [2 ** (-8 * (h + 1) / num_heads) for h in range(num_heads)],
            dtype=torch.float32, device="cuda",
        )
        # Shared offset tensor, updated in place per call, so the same mask_mod and
        # score_mod objects are reused instead of being rebuilt as new closures.
        self._q_offset = torch.zeros((), dtype=torch.int32, device="cuda")
        self.mask_mod, self.score_mod = make_mods(slopes, self._q_offset)
        self._compiled_flex = torch.compile(flex_attention)

    def forward(self, query, key, value, kv_cache, attn_metadata, **kwargs):
        # This backend handles one sequence at a time (batch_size=1). Batched
        # decode with N sequences would require iterating over block_tables per
        # sequence and calling flex_attention separately for each; doing so with
        # this implementation silently reads only sequence 0's KV cache for all N
        # tokens. Raise early so the failure is explicit rather than silent.
        if len(attn_metadata.seq_lens) != 1:
            raise ValueError(
                f"SlidingWindowALiBiBackend only supports batch_size=1 "
                f"(got {len(attn_metadata.seq_lens)} sequences). "
                "For batched decode, iterate over block_tables per sequence."
            )

        # Write the current-step key/value tokens into the paged KV cache at their
        # allocated slots before reading back the full accumulated sequence.
        key_cache, value_cache = kv_cache  # each: (num_blocks, num_heads, block_size, head_dim)
        torch.ops.vllm.reshape_and_cache_flash(
            key, value, key_cache, value_cache,
            attn_metadata.slot_mapping, kv_cache_dtype="auto",
        )

        # query holds the full prompt (prefill) or the single new token (decode).
        # seq_lens[0] is the exact filled length of this sequence; int() keeps the
        # lru_cache key a plain Python int even if seq_lens is a tensor.
        num_tokens = query.shape[0]
        kv_len = int(attn_metadata.seq_lens[0])

        # The first query token sits at absolute position kv_len - num_tokens. Set
        # the offset before building or reusing the mask: create_block_mask
        # evaluates mask_mod eagerly, and the compiled kernel reads the same tensor.
        self._q_offset.fill_(kv_len - num_tokens)
        block_mask = get_cached_block_mask(
            self.mask_mod, num_tokens, kv_len, self.num_heads, str(query.device)
        )

        # unsqueeze(0) adds the batch dimension (batch_size=1).
        q = query.unsqueeze(0).transpose(1, 2)  # (1, num_heads, num_tokens, head_dim)

        # Reconstruct contiguous (1, num_heads, kv_len, head_dim) tensors from the
        # paged cache. block_tables[0] is safe: we checked for a single sequence above.
        # Reshape gathers physical blocks into a contiguous buffer; trim to kv_len to
        # discard unused slots in the last partially filled block.
        seq_blocks = attn_metadata.block_tables[0]  # (max_blocks_per_seq,)
        k = key_cache[seq_blocks].permute(0, 2, 1, 3).reshape(1, -1, self.num_heads, key_cache.shape[-1])
        v = value_cache[seq_blocks].permute(0, 2, 1, 3).reshape(1, -1, self.num_heads, value_cache.shape[-1])
        k = k[:, :kv_len].permute(0, 2, 1, 3).contiguous()  # (1, num_heads, kv_len, head_dim)
        v = v[:, :kv_len].permute(0, 2, 1, 3).contiguous()

        out = self._compiled_flex(
            q, k, v,
            score_mod=self.score_mod,
            block_mask=block_mask,
            scale=self.scale,
        )
        # Reshape output from (1, num_heads, num_tokens, head_dim) back to
        # (num_tokens, num_heads, head_dim) as vLLM expects.
        return out.squeeze(0).transpose(0, 1)

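The @register_backend(AttentionBackendEnum.CUSTOM) decorator registers the class with vLLM's v1 attention registry, but only in the process that imports my_flex_backend. Running python -c "import my_flex_backend" as a separate command registers nothing in the server, so import the module in the same process that starts the server, then pass --attention-backend CUSTOM to select it. For multi-GPU runs, worker processes need the registration too; vLLM's plugin mechanism (the vllm.general_plugins entry-point group) is the usual way to run such code in every process.

bash
python - <<'EOF'
import runpy
import sys

import my_flex_backend  # runs @register_backend in this (server) process

sys.argv = [
    "api_server",
    "--model", "meta-llama/Meta-Llama-3-70B-Instruct",
    "--attention-backend", "CUSTOM",
    "--tensor-parallel-size", "2",
]
runpy.run_module("vllm.entrypoints.openai.api_server", run_name="__main__", alter_sys=True)
EOF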

For the full vLLM deployment configuration including CUDA graph tuning and batching, see the vLLM production deployment guide.

Step 5: Profile with PyTorch Profiler

python
import torch
from torch.profiler import profile, ProfilerActivity, tensorboard_trace_handler

with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    on_trace_ready=tensorboard_trace_handler("./flex_trace"),
    record_shapes=True,
    with_stack=True,
) as prof:
    for _ in range(10):
        out = compiled_flex(Q, K, V, score_mod=identity_score_mod, block_mask=block_mask)
        torch.cuda.synchronize()

print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))

Look for flex_attention in the CUDA kernel names. If you see scaled_dot_product_attention_flash_attention instead, you hit a fallback to dense SDPA. Causes: unsupported head_dim, None block_mask, or a graph break in score_mod.

For deeper custom kernel integration and Triton kernel debugging, see the OpenAI Triton kernel development guide.


FlexAttention turns any mask pattern into a compiled, FlashAttention-grade kernel. Spheron's bare-metal H200 and B200 instances ship with PyTorch 2.6 prebuilt, so you skip the container fights and go straight to benchmarking.



Quick Setup Guide

  1. Provision an H200 or B200 instance on Spheron

    Go to app.spheron.ai, select GPU Cloud, and filter for H200 SXM5 or B200 SXM6. Choose an instance with CUDA 12.4 or later. H200 instances are available at $4.62/hr on-demand or $1.49/hr spot. B200 instances start at $7.21/hr on-demand. SSH in and verify the GPU with nvidia-smi.

  2. Install PyTorch 2.6 and verify FlexAttention availability

    Run: pip install torch==2.6.0 --index-url https://download.pytorch.org/whl/cu124. Then verify: python -c "from torch.nn.attention.flex_attention import flex_attention, create_block_mask; print('FlexAttention available')". If you see the print output, the install is correct.

  3. Write a sliding-window mask_mod (with an identity score_mod)

    Define an identity score_mod: def identity_score_mod(score, b, h, q_idx, kv_idx): return score. Then define the mask: def sw_mask(b, h, q_idx, kv_idx): return (q_idx >= kv_idx) & (q_idx - kv_idx <= window_size). Capture window_size as a Python constant or via functools.partial, not through Python-side state that changes between calls. Call bm = create_block_mask(sw_mask, B, H, SEQ_LEN, SEQ_LEN, device="cuda") to precompute the block-sparse mask, then pass it to flex_attention(Q, K, V, score_mod=identity_score_mod, block_mask=bm).

  4. Apply torch.compile and benchmark against SDPA baseline

    Wrap your FlexAttention call with torch.compile: compiled_flex = torch.compile(flex_attention). Run a warmup pass (at least 3 forward passes) to trigger kernel compilation, then time with torch.cuda.synchronize() brackets. Compare against F.scaled_dot_product_attention(Q, K, V, is_causal=True) for the same shapes to measure the FlexAttention kernel overhead or speedup.

  5. Integrate a FlexAttention backend into vLLM

    Subclass AttentionBackend (imported from vllm.v1.attention.backend in the Deployment Recipe) and implement forward() to call flex_attention with your cached block_mask. Register the class with @register_backend(AttentionBackendEnum.CUSTOM), make sure my_flex_backend is imported inside the vLLM server process, and launch with --attention-backend CUSTOM. See the Deployment Recipe section for the full class skeleton.

  6. Profile FlexAttention kernels with PyTorch Profiler

    Wrap your inference call with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA]). Look for flex_attention entries in the CUDA trace. Check that the kernel name contains 'flex' or 'block_sparse' rather than 'scaled_dot_product_attention_flash_attention', which would indicate a fallback to dense SDPA.


Frequently Asked Questions

FlexAttention is a PyTorch 2.5+ API in torch.nn.attention.flex_attention that lets you define custom attention mask patterns as Python functions (score_mod and mask_mod), which torch.compile compiles into block-sparse FlashAttention-equivalent CUDA kernels via Triton. FlashAttention is a fixed implementation of standard causal or full attention optimized for hardware memory access patterns. FlexAttention gives you FlashAttention-grade kernel efficiency with arbitrary mask logic you write in Python, without needing to implement a CUDA kernel yourself.

This guide focuses on Hopper and newer GPUs (compute capability 9.0+): H100, H200, and B200 all support FlexAttention. On Hopper (H100, H200) it generates Triton-based block-sparse kernels. On Blackwell (B200, B300) it falls through to FlashAttention-4's SM100 tile kernels when available. Its Triton kernels also run on Ampere GPUs such as the A100 (the original PyTorch FlexAttention announcement reported its benchmarks on A100), so block-sparse tile skipping is not Hopper-only, although absolute throughput is lower on older hardware.

Yes, torch.compile is required to get optimized kernels from FlexAttention. Without torch.compile, score_mod and mask_mod functions run in eager mode as element-wise Python operations, which is far slower than FlashAttention. With torch.compile, Dynamo traces the score_mod and mask_mod functions, generates a fused block-sparse Triton kernel, and produces near-FlashAttention throughput. The score_mod and mask_mod functions must be traceable tensor code (no Python-side state, no indexing of Python containers with the index tensors, no Python control flow on tensor values) to avoid graph breaks.

Use FlashAttention-4 (via vLLM or SGLang auto-detection) for standard causal LLM inference on Blackwell hardware where you do not need custom mask patterns. Use FlexAttention when you have a non-standard mask pattern such as sliding window, ALiBi, document masking, or causal+prefix LM, and want to implement it without writing CUDA. Use SGLang RadixAttention for multi-turn chat and agent workloads where KV cache prefix reuse between requests is the bottleneck, not the attention mask shape.

FlexAttention installs shape guards for each (batch, heads, seq_len, head_dim) tuple. If seq_len changes frequently, torch.compile recompiles on each new shape, which eliminates the speedup and adds latency. Two fixes: (1) use torch._dynamo.mark_dynamic on the seq_len dimension before compiling, which causes Dynamo to generate a symbolic kernel that works across seq_len values, or (2) bucket-pad inputs to fixed sequence lengths such as 512, 1024, 2048, 4096, 8192, and always pass a fixed shape. Option 2 wastes some compute on padding tokens but is simpler to implement in production.
