---
title: Making FlashAttention-4 Faster for Inference
canonical: "https://agenticup.dev/posts/flashattention-4-inference-optimization/"
pubDate: "2026-06-12T00:00:00.000Z"
description: "Modal's engineering team made three targeted changes to FlashAttention-4 that improved inference throughput by up to 4.37x: split KV parallelism, FP8 input support, and arbitrary KV page sizes."
tags: [flashattention-4, inference, optimization, gpu, attention, llm-inference, transformer, cuda]
---

**TL;DR:** I read Modal's FlashAttention-4 post three times. Each read revealed another optimization I had been leaving on the table. Split KV parallelism. FP8 input support. Arbitrary page sizes. None requires a PhD.

Modal made three changes to FlashAttention-4 that improved inference throughput by up to 4.37x. No new attention algorithm, just targeted changes to the existing kernels.

> **Key takeaways:**
> - Split KV parallelism (PR 1940) improves throughput by up to 4.37x for small query lengths, the single biggest gain for decode-heavy inference
> - FP8 input support (PR 2109) gives up to 1.16x over BF16 with no accuracy loss, while enabling smaller KV caches for longer contexts
> - Arbitrary KV page sizes (PR 2104) improve small-page throughput by up to 2.40x, which is critical for speculative decoding, where large pages cause fragmentation
> - All changes are kernel-level, open-source, and require zero model changes
> - The key insight: inference parallelism should target the KV dimension, not the query dimension

FlashAttention-4 brought SM100 tile-based kernels to Blackwell GPUs. It was designed for training: processing full sequences, saturating all available compute, pushing attention throughput to petaflop territory.

Inference is different.

During autoregressive decoding, the query is short, often a single token. The key-value cache is long, tens of thousands of tokens. The attention kernel needs to match a tiny query against a massive KV cache. Most GPU streaming multiprocessors sit idle. In the worst case, up to 75% of them do nothing while a few handle the work.

Modal's engineering team (Charles Frye, Timothy Feng, David Wang) targeted this gap. They made three changes to FlashAttention-4's CUDA implementation that improve inference throughput by up to 4.37x. The changes live in the open-source repository as pull requests that are already merged into main.

## 1. Split KV parallelism (PR 1940): up to 4.37x

The original FlashAttention-4 parallelizes across query tiles. For training with full sequences, this works well: there are many query tiles to distribute across SMs.

For inference with a single query token, there's one query tile. One tile, hundreds of SMs. Most sit idle.

The fix: parallelize across the **KV dimension** instead. Split the KV cache into chunks, assign each chunk to a different thread block, compute partial attention for each chunk, then combine the results with a separate reduction kernel (`flash_fwd_combine`).
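To see why splitting helps, here is a back-of-the-envelope occupancy sketch. The SM count and CTA count are hypothetical numbers picked for illustration, and the model assumes one CTA per SM while ignoring the cost of the combine step.

```python
def busy_sm_fraction(num_ctas: int, num_sms: int) -> float:
    """Fraction of SMs that receive at least one CTA (one CTA per SM assumed)."""
    return min(num_ctas, num_sms) / num_sms

NUM_SMS = 128          # hypothetical GPU, not a real device spec
QUERY_TILE_CTAS = 32   # hypothetical: batch * heads * query tiles during decode

for num_splits in (1, 2, 4):
    # Splitting the KV cache multiplies the number of independent CTAs.
    ctas = QUERY_TILE_CTAS * num_splits
    busy = busy_sm_fraction(ctas, NUM_SMS)
    print(f"splits={num_splits}: {ctas} CTAs, {busy:.0%} of SMs busy")
```

With these made-up numbers, query-only parallelism keeps a quarter of the SMs busy, and four KV splits are enough to cover all of them.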

| KV splits | Memory throughput (TB/s) |
|-----------|-------------------------|
| 1 (baseline) | 0.83 |
| 2 | 2.65 |
| 4 | 4.30 |
| 8 | 4.27 |
| 16 | 4.22 |
| 32 | 4.37 |
| 64 | 4.17 |

The sweet spot is 4 to 32 splits, which deliver 4.22 to 4.37 TB/s against a 0.83 TB/s baseline. Beyond 32 splits, the overhead of the reduction kernel starts eating into the gains.

This technique is called Flash-Decoding, originally developed for FlashAttention-2. Modal ported it to FA4 with a `num_splits` argument. The implementation is clean: multiple CTAs work on each query tile, each covering a portion of the KV sequence, followed by a separate combine step.
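The math behind the combine step can be sketched in plain Python. This is a minimal illustration, not the FA4 kernel: the function names are my own, the chunks run sequentially rather than on separate CTAs, and everything is in double precision. Each chunk returns its locally normalized output plus its log-sum-exp, and the combine step reweights the partial outputs by each chunk's share of the total softmax mass.

```python
import math
import random

def attention_partial(q, keys, values):
    """Attention of one query over one KV chunk.

    Returns the chunk-local softmax output and the chunk's log-sum-exp,
    which is all the combine step needs.
    """
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(a * b for a, b in zip(q, key)) for key in keys]
    top = max(scores)                              # subtract max for stability
    weights = [math.exp(s - top) for s in scores]
    total = sum(weights)
    out = [sum(w * v[d] for w, v in zip(weights, values)) / total
           for d in range(len(values[0]))]
    return out, top + math.log(total)

def combine(partials):
    """Merge per-chunk outputs, weighting each by its share of softmax mass."""
    top = max(lse for _, lse in partials)
    shares = [math.exp(lse - top) for _, lse in partials]
    total = sum(shares)
    dim = len(partials[0][0])
    return [sum(s * out[d] for s, (out, _) in zip(shares, partials)) / total
            for d in range(dim)]

def attention_split_kv(q, keys, values, num_splits):
    """Split the KV cache into chunks, attend to each, then combine."""
    chunk = math.ceil(len(keys) / num_splits)
    partials = [attention_partial(q, keys[i:i + chunk], values[i:i + chunk])
                for i in range(0, len(keys), chunk)]
    return combine(partials)

random.seed(0)
dim, kv_len = 8, 64
q = [random.gauss(0, 1) for _ in range(dim)]
keys = [[random.gauss(0, 1) for _ in range(dim)] for _ in range(kv_len)]
values = [[random.gauss(0, 1) for _ in range(dim)] for _ in range(kv_len)]

unsplit, _ = attention_partial(q, keys, values)
split = attention_split_kv(q, keys, values, num_splits=4)
assert all(abs(a - b) < 1e-9 for a, b in zip(unsplit, split))
```

The split result matches the unsplit one up to floating-point error, which is why the split count can be tuned purely for throughput.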

## 2. FP8 input support (PR 2109): up to 1.16x

FlashAttention-4 was designed for BF16 inputs. Modal added support for FP8 (`e4m3` or `e5m2` format), reducing memory and arithmetic bandwidth demand by 2x versus BF16.

| Batch / Seq Len | BF16 TFLOP/s | FP8 TFLOP/s | Speedup |
|-----------------|-------------|-------------|---------|
| 1 / 16384 | 1569 | 1818 | 1.16x |
| 32 / 512 | 962 | 1090 | 1.13x |

The speedup falls short of 2x because softmax still runs at higher precision on CUDA cores; only the matrix multiplications on Tensor Cores use the low-precision inputs. Still, 1.13-1.16x is meaningful, and it comes with an additional benefit: **smaller KV caches**.

FP8 KV caches use half the memory of BF16 caches. This means longer context windows or higher batch concurrency for the same memory budget. For inference serving, memory is often the bottleneck. FP8 gives you headroom on both throughput and capacity.
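The capacity gain is simple arithmetic. The sketch below uses a hypothetical model shape chosen only for illustration, and it ignores any scaling-factor metadata an FP8 cache may also need to store.

```python
def kv_cache_bytes(num_layers, num_kv_heads, head_dim, seq_len, batch_size,
                   bytes_per_element):
    """Bytes needed to hold keys and values for every layer and token."""
    keys_and_values = 2
    return (keys_and_values * num_layers * num_kv_heads * head_dim
            * seq_len * batch_size * bytes_per_element)

# Hypothetical model shape, not taken from any specific model.
shape = dict(num_layers=32, num_kv_heads=8, head_dim=128,
             seq_len=32_768, batch_size=8)

bf16 = kv_cache_bytes(**shape, bytes_per_element=2)
fp8 = kv_cache_bytes(**shape, bytes_per_element=1)
print(f"BF16: {bf16 / 2**30:.1f} GiB, FP8: {fp8 / 2**30:.1f} GiB")
# prints "BF16: 32.0 GiB, FP8: 16.0 GiB"
```

For this shape, the 16 GiB saved could go toward doubling the context length or doubling the batch size within the same memory budget.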

## 3. Arbitrary KV page sizes (PR 1999 + PR 2104): up to 2.40x

This is the most technically interesting change.

The original FlashAttention-4 required the KV cache page size to equal the tile size, because TMA (Tensor Memory Accelerator) loads require large, affine memory accesses. This is fine for training but causes fragmentation in inference, especially for speculative decoding with short sequences.

Modal's PagedKVManager path uses `cp.async` (without TMA) to support arbitrary page sizes. Early performance was poor for small pages (18.56 TFLOP/s at page_size=1). The fix was a clever **transpose strategy** that decouples address generation from address use:

Instead of each thread computing its own pointer (redundant int64 computation), organize 32 threads in a warp as 4x8 groups. Threads in a column compute sequential row pointers and shuffle the results across the warp. Cross-thread warp shuffle is cheaper than redundant pointer math.
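A toy simulation makes the saving concrete. This is my own sketch of the idea, not the code from the PR: the lane-to-row assignment, the constants, and the function names are all assumptions, and a list lookup stands in for the warp shuffle.

```python
PAGE_SIZE = 1                    # tokens per KV page (hypothetical)
ROW_BYTES = 256                  # bytes per KV row (hypothetical)
WARP_SIZE = 32
GROUP_ROWS, GROUP_COLS = 4, 8    # the 32 lanes viewed as a 4x8 grid

def row_address(row, page_table):
    """Page-table lookup plus offset math for one logical KV row."""
    page, offset = divmod(row, PAGE_SIZE)
    return (page_table[page] * PAGE_SIZE + offset) * ROW_BYTES

def rows_for_lane(lane):
    """Rows of a 32-row tile that this lane loads (assumed layout)."""
    return range(lane // GROUP_COLS, WARP_SIZE, GROUP_ROWS)

def naive_addresses(page_table):
    """Every lane recomputes the address of every row it loads."""
    per_lane = [[row_address(r, page_table) for r in rows_for_lane(lane)]
                for lane in range(WARP_SIZE)]
    computed = sum(len(rows_for_lane(lane)) for lane in range(WARP_SIZE))
    return per_lane, computed

def shuffled_addresses(page_table):
    """Each lane computes one row address; lanes then exchange results."""
    owned = [row_address(lane, page_table) for lane in range(WARP_SIZE)]
    # Reading owned[r] models a warp shuffle that fetches the value held by lane r.
    per_lane = [[owned[r] for r in rows_for_lane(lane)]
                for lane in range(WARP_SIZE)]
    return per_lane, WARP_SIZE

page_table = list(reversed(range(WARP_SIZE)))   # scattered physical pages
naive, naive_count = naive_addresses(page_table)
shuffled, shuffled_count = shuffled_addresses(page_table)
assert naive == shuffled
print(f"address computations per warp: {naive_count} vs {shuffled_count}")
```

Under these assumptions, both schemes hand every lane the same addresses, but the shuffled version performs one lookup per lane instead of one per lane per row.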

| Page Size | Before (PR 1999) | After (PR 2104) | Speedup |
|-----------|-----------------|-----------------|---------|
| 1 | 18.56 TFLOP/s | 44.57 TFLOP/s | 2.40x |
| 8 | 31.21 TFLOP/s | 42.58 TFLOP/s | 1.37x |
| 32 | 34.98 TFLOP/s | 42.47 TFLOP/s | 1.21x |
| 128 | 42.11 TFLOP/s | 41.96 TFLOP/s | 1.00x |

The 2.40x improvement at page_size=1 matters because large page sizes cause internal fragmentation. Speculative decoding generates short sequences: page_size=128 wastes most of each page. With arbitrary page sizes, you can match page size to sequence length, minimizing waste.
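The waste is easy to quantify. In this sketch the 5-token draft length is a hypothetical example, and the helper name is my own.

```python
def wasted_slots(seq_len: int, page_size: int) -> int:
    """Allocated-but-unused token slots when seq_len is rounded up to whole pages."""
    pages = -(-seq_len // page_size)   # ceiling division
    return pages * page_size - seq_len

DRAFT_LEN = 5   # hypothetical speculative draft length, in tokens
for page_size in (1, 8, 32, 128):
    waste = wasted_slots(DRAFT_LEN, page_size)
    allocated = DRAFT_LEN + waste
    print(f"page_size={page_size}: {waste} of {allocated} slots unused")
```

At page_size=128, a 5-token draft leaves 123 of 128 slots empty, while page_size=1 leaves none.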

## The pattern

These three optimizations share a common insight: **inference parallelism should target the KV dimension, not the query dimension**.

FlashAttention was designed for training: full sequences, maximal compute use. Inference with autoregressive decoding is fundamentally different: tiny queries against massive KV caches. The parallelism strategy that works for training leaves most of the GPU idle during decoding.

Modal's changes are all kernel-level. No model architecture changes. No retraining. No accuracy loss. If you're running FA4 for inference, these improvements are available today in the open-source repository.

The paper accompanying the work (FlashAttention-4: Algorithm and Kernel Pipelining Co-Design) describes the algorithms. The CUDA implementation is in the PRs. Between the two, there's enough detail to apply the same patterns to your own attention kernels, or you can just update your FlashAttention checkout and measure the difference.

## FAQ

> **Do I need a Blackwell GPU for these optimizations?**
> The split KV parallelism and arbitrary page size changes work on any GPU that runs FlashAttention-4. The FP8 input support benefits most from Blackwell Tensor Cores but the technique is architecture-agnostic. All changes are in the open-source repository.
>
> **Will these optimizations be included in FlashAttention-4 by default?**
> PR 1999, PR 2104, PR 1940, and PR 2109 are merged into the FlashAttention-4 main branch. Rebuild from the latest commit and they're included.
>
> **What's the easiest way to use these optimizations?**
> Update your FlashAttention-4 installation to the latest commit. The changes are backward-compatible, so existing code runs without modification. Split KV parallelism is controlled by a `num_splits` argument that defaults to 1 (the original behavior), so raise it to opt in.

## Related Posts

Read [AI agent cost optimization tips](/posts/ai-agent-cost-optimization-tips/) for practical strategies to reduce LLM inference costs in production. Also see [Best open-source LLMs for coding 2026](/posts/best-open-source-llms-coding-2026/) for comparison of models that benefit from attention optimization.



The [Lambda blog](https://lambda.ai/blog/flashattention-4-gives-the-nvidia-blackwell-platform-its-most-improved-attention-kernel-yet) covers the theoretical FLOPs ceiling of FlashAttention-4 on Blackwell hardware. Modal's engineering approach builds on the [FlashAttention-4 Triton discussion](https://www.reddit.com/r/LocalLLaMA/comments/1s1yw23/flashattention4_1613_tflopss_27x_faster_than/) on the LocalLLaMA subreddit.

---

This article was published on Agentic Up (https://agenticup.dev): practical guides for developers and founders building with AI agents. Reach me at hello@agenticup.dev.
