{"slug": "making-flashattention-4-faster-for-inference", "title": "Making FlashAttention-4 faster for inference", "summary": "Modal AI engineers Charles Frye and David Wang optimized FlashAttention-4 for large language model inference, focusing on decode-heavy workloads dominated by memory bandwidth-limited token generation. The team adjusted parallelism strategies and replaced Tensor Memory Accelerator loads with cp.async operations to handle variable batch sizes and sequence lengths common in inference. These changes improve kernel performance for production inference workloads, where performance directly impacts product quality.", "body_md": "# Making FlashAttention-4 faster for inference\n\n[Charles Frye (@charles_irl)](https://twitter.com/charles_irl)\n\n[David Wang (@_dcw02)](https://twitter.com/_dcw02)\n\nWhen the FlashAttention-4 kernel source was released last year, we dove in and [shared our findings about how the kernel works in ~~excruciating~~ exquisite detail](/blog/reverse-engineer-flash-attention-4). You can now confirm the high-level structure we inferred by reading [this post](https://www.together.ai/blog/flashattention-4) straight from the horse’s mouth.\n\nIn the intervening months, we’ve made a number of contributions to this kernel to make it more suitable for large language model inference and in particular for decode-heavy workloads. 
Unlike pre-training workloads, LLM inference workloads are often dominated by the [memory bandwidth-limited](https://modal.com/gpu-glossary/perf/memory-bound) “decode” or “token generation” phase (light blue, below).\n\nInference workloads are also generally more variable — batch sizes and sequence lengths become non-uniform; keys and values must be retrieved from cache (most of the time).\n\nThis requires new kernel code, and that code must be fast: [“performance is the product”](https://modal.com/gpu-glossary/perf).\n\nBefore we dive into the details, some takeaways for a more general audience.\n\n# High-level takeaways about low-level programming\n\nOur changes to extend the kernel to the inference workloads we wanted to run can be lumped into two rough categories:\n\n**adjusting the parallelism strategy**, i.e. the number of query tiles per thread block and switching from query parallelism to key/value parallelism, and **supporting irregular global memory accesses**, i.e. `cp.async` loads to replace `cp.async.bulk` loads using the [Tensor Memory Accelerator (TMA)](https://modal.com/gpu-glossary/device-hardware/tensor-memory-accelerator).\n\nThese two categories are represented by the following figures, which are explained in detail below.\n\nAdjusting parallelism strategies gives the largest leverage in improving performance on [modern massively parallel hardware](https://modal.com/gpu-glossary/perf/roofline-model). Intuitively: if you are locked into a specific approach to parallelism, the sequential term in Amdahl’s Law is fixed. If you can change parallelism strategies, you can move work between the parallel and sequential components of your algorithm. This is, per the Law, generally higher leverage than increasing the speed of a fixed parallel component.\n\nWe didn’t choose the [CUDA Templates Domain Specific Language](https://modal.com/gpu-glossary/host-software/cute-dsl) (CuTe DSL), the original kernel authors did, but it worked well for us. 
It supports highly productive development loops through fast JIT compilation with minimal or zero run-time cost. It also made expressing many of our ideas more straightforward than older tools. Note that because it uses templates, FA4 is really a *family* of kernels, if “kernel” means roughly “something that can be launched into a CUDA stream”. We’ll keep calling it a “kernel”.\n\nCuTe DSL was nice. But, as we indicated in [our previous post](/blog/reverse-engineer-flash-attention-4), FA4 is best understood algorithmically at the tile level, not at the [warp](https://modal.com/gpu-glossary/device-software/warp) level at which it is implemented. It’s clear that proper tile-based programming would be better for ergonomics and development speed (which, by the way, [still matters in the age of agents](/blog/agents-devex)). With a tile-based programming model, programmers can more simply express and operate on tile-level flows. That makes it easier to change or add algorithms to kernels at lower engineering cost (the first category of changes). Furthermore, higher-level tile-based models make it easier for compilers to implement and optimize, say, both `cp.async` and TMA load paths (the second category) and dispatch based on, say, size.\n\nIn this light, we’re very much looking forward to improved support for the [CUDA Tile programming model](https://modal.com/gpu-glossary/device-software/cuda-tile-programming-model), as distinct from the classic [“CUDA SIMT” programming model](https://modal.com/gpu-glossary/device-software/cuda-programming-model), to build the attention and matmul kernels of the future.\n\n# What we did, why, and how we knew it was good\n\nWe organize our contributions by pull request. Each section begins with a “Figure of Merit”: the measurement used to indicate that the contribution improved performance. 
We report these figures in the traditional format of the performance engineer: an ASCII table.\n\n[PR 2109](https://github.com/Dao-AILab/flash-attention/pull/2109): support FP8 inputs (merged April 17, 2026)\n\n### Figure of Merit: Up to 1.16x throughput relative to bf16 baseline\n\n```\n| Batch Size / Seq Len | BF16 TFLOP/s  | FP8 TFLOP/s | Speedup |\n| -------------------- | ------------- | ----------- | ------- |\n| 1 / 16384            | 1569          | 1818        | 1.16x   |\n| 32 / 512             | 962           | 1090        | 1.13x   |\n```\n\nTraining models generally requires higher precision [floating point numbers](/llm-almanac/quant-formats/) to properly accumulate many small changes inside gradients. But at inference time, we can get away with lower precision. Reducing the bit width by a factor of two reduces memory and arithmetic bandwidth demand by a factor of two without nearly as large a hit to model quality.\n\nThis is especially true of the MLP/MoE layers of large models, which often use diminutive, “nibble”-sized [4 bit floating point numbers](/llm-almanac/quant-formats/0x6). Attention operations, especially on long contexts, involve more accumulations and so are harder to quantize. Models like [gpt-oss](https://modal.com/docs/examples/gpt_oss_inference) combine [single-precision](/llm-almanac/quant-formats/bf::0x0380) attention operations with 4 bit matmuls to get the best of both worlds.\n\nHowever, key model families like [DeepSeek-V3 and V4](https://modal.com/docs/examples/deepseek_v4) natively (i.e., from training) support [8 bit](https://modal.com/llm-almanac/quant-formats/e4::0x58) attention operations. And other models like the Qwen and Gemma series are sometimes deployed with 8 bit KV caches to accelerate inference.\n\nSo [we added support](https://github.com/Dao-AILab/flash-attention/pull/2109) for 8 bit floats (with either four or five exponent bits, aka [`e4m3`](/llm-almanac/quant-formats/e4::0x38) or [`e5m2`](/llm-almanac/quant-formats/e5::0x1c)). Relative to the other changes discussed below, this is pretty unsubtle: fewer bytes moved and operated on means faster inference! It also means smaller KV caches, which means longer contexts and/or increased user concurrency during inference.\n\nNotably, the speedup is less than the 2x you might expect from a 2x reduction in bit width, which cuts demand for both [memory bandwidth](https://modal.com/gpu-glossary/perf/memory-bandwidth) and (effective) [arithmetic bandwidth](https://modal.com/gpu-glossary/perf/arithmetic-bandwidth) by two. Determining the specific [bottleneck](https://modal.com/gpu-glossary/perf/performance-bottleneck) here would require a more detailed analysis. But the result is in line with a bottleneck in the softmax operation, which still operates at the same precision (on [CUDA Cores](https://modal.com/gpu-glossary/device-hardware/cuda-core) and/or [Special Function Units](https://modal.com/gpu-glossary/device-hardware/special-function-unit)) even as the [Tensor Cores](https://modal.com/gpu-glossary/device-hardware/tensor-core) operate on lower-precision inputs.\n\n[PR 1999](https://github.com/Dao-AILab/flash-attention/pull/1999) and [PR 2104](https://github.com/Dao-AILab/flash-attention/pull/2104): support arbitrary KV page sizes (merged November 13, 2025) and optimize performance (merged January 15, 2026)\n\n### Figure of Merit: Up to 2.40x throughput for small page sizes\n\n```\n| Page Size | Added in PR 1999? 
| TFLOP/s, PR 1999 | TFLOP/s, PR 2104 | Speedup |\n| --------- | ----------------- | ---------------- | ---------------- | ------- |\n| 1         | y                 | 18.56            | 44.57            | 2.40x   |\n| 8         | y                 | 31.21            | 42.58            | 1.37x   |\n| 32        | y                 | 34.98            | 42.47            | 1.21x   |\n| 128       | n                 | 42.11            | 41.96            | -       |\n```\n\nFlashAttention-4 [operates on tiles](/blog/reverse-engineer-flash-attention-4) sized to make effective use of the Blackwell [Tensor Cores](https://modal.com/gpu-glossary/device-hardware/tensor-core). During the decode phase of inference, the tiles for the key and value tensors are constructed out of entries in the KV cache, populated during prefill. In the original version of FlashAttention-4, the KV cache pages needed to be the same size as the tiles.\n\nThis restriction came from the kernel’s use of the [Tensor Memory Accelerator (TMA)](https://modal.com/gpu-glossary/device-hardware/tensor-memory-accelerator), a hardware engine for certain regular memory accesses in GPUs with the Hopper and Blackwell [Streaming Multiprocessor (SM) architecture](https://modal.com/gpu-glossary/device-hardware/streaming-multiprocessor-architecture). The TMA substantially accelerates large affine memory accesses — those that look like “offset plus stride times shape” for many strides, as when accessing via a [CuTe Layout](https://modal.com/gpu-glossary/host-software/cute). 
This works nicely for accessing [page-based KV caches](https://arxiv.org/abs/2309.06180) if the page size is large enough.\n\nBut the TMA can’t gather multiple scattered blocks into a single tile in a single load, and it doesn’t speed up (and may slow down) smaller loads, which are a consequence of smaller page sizes.\n\nSo we added a path that uses `cpasync`, CuTe DSL’s wrapper for [PTX](https://modal.com/gpu-glossary/device-software/parallel-thread-execution) `cp.async` instructions, via a `PagedKVManager`.\n\nIn the TMA-based version, a single [thread](https://modal.com/gpu-glossary/device-software/thread) out of a [warp](https://modal.com/gpu-glossary/device-software/warp) was responsible for loading a tile — the “producer group” in the producer-consumer model is a single thread.\n\nIn the `cpasync` version, each thread issues a load (with warps’ loads [coalesced](https://modal.com/gpu-glossary/perf/memory-coalescing) by the hardware), so each thread calculates its own `page` and `offset` within the page. This is simple but inefficient; more on that later!\n\nWe repurposed the otherwise idle warp 15 to handle this extra work — the producer group comprises two warps.\n\nIn this first PR, these smaller page sizes had lower arithmetic and memory throughput. But in many inference workloads, KV cache efficiency matters a lot, so this can be a good trade to make.\n\nFirst, large page sizes can lead to unnecessary duplication. If several requests share a prefix of, say, 64 tokens, but differ after that point, an attention kernel with `page_size=128` will require a separate page for each request, since the prefix is shorter than the page size. 
An attention kernel with `page_size=16` can share four pages across the requests, reducing the storage required multiplicatively by the number of requests (cf. the sharing of the prefix “Thou shalt not” across three requests on the left-hand side of the figure below, vs. its three-fold repetition in the KV cache with larger `page_size` on the right).\n\nSecond, large page sizes lead to substantial internal fragmentation of the KV cache. Short sequences still require full pages — in the worst case, a single token consumes an entire page that could hold KV cache data for 128 tokens. That’s >99% internal fragmentation for that page. This consumes ~8x the capacity of a `page_size=16` KV cache, which would have “only” 93.75% internal fragmentation.\n\nThis is especially important for speculative decoding. Speculators create many short (~1-16 token) sequences in the KV cache, and with large page sizes, each of those consumes much more space.\n\nSupporting arbitrary page sizes was already a win for compatibility, but the first implementation came at a performance cost. For `page_size=1`, the most extreme case, memory throughput for [memory-bound](https://modal.com/gpu-glossary/perf/memory-bound) cases of the FA4 kernel was under half the effective [memory bandwidth](https://modal.com/gpu-glossary/perf/memory-bandwidth), and arithmetic throughput for [compute-bound](https://modal.com/gpu-glossary/perf/compute-bound) cases was under one third the effective [arithmetic bandwidth](https://modal.com/gpu-glossary/perf/arithmetic-bandwidth). We fixed the performance in [a follow-up PR](https://github.com/Dao-AILab/flash-attention/pull/2104).\n\nA similar problem affected the FlashAttention-3 kernel, so we ported the strategy over to the FA4 `PagedKVManager`.\n\nThe key move was decoupling address *generation* from address *use* to reduce redundant computation. This is done by “transposing” address generation, as described below. 
The approach is also detailed in Section 4.2 in [this paper by Zadouri et al](https://arxiv.org/abs/2505.21487).\n\nWe organize the 32 [threads](https://modal.com/gpu-glossary/device-software/thread) in each [warp](https://modal.com/gpu-glossary/device-software/warp) as an array with four “row” thread groups with eight “columns” of threads each:\n\n```\n                group thread index\n                ------------------------------------------\n                      0    1    2    3    4    5    6    7\n\nwarp thread index     0    1    2    3    4    5    6    7\nwarp thread index     8    9   10   11   12   13   14   15\nwarp thread index    16   17   18   19   20   21   22   23\nwarp thread index    24   25   26   27   28   29   30   31\n```\n\nOur original approach had each thread compute the pointer for the KV cache row that it was also responsible for loading.\n\n```\nloop k (0..7)\n                    row pointer produced by thread\n                    ------------------------------\ngroup 0        4k   4k   4k   4k   4k   4k   4k   4k\ngroup 1        4k+1 4k+1 4k+1 4k+1 4k+1 4k+1 4k+1 4k+1\ngroup 2        4k+2 4k+2 4k+2 4k+2 4k+2 4k+2 4k+2 4k+2\ngroup 3        4k+3 4k+3 4k+3 4k+3 4k+3 4k+3 4k+3 4k+3\n\n                    row loaded by thread\n                    --------------------\ngroup 0        4k   4k   4k   4k   4k   4k   4k   4k\ngroup 1        4k+1 4k+1 4k+1 4k+1 4k+1 4k+1 4k+1 4k+1\ngroup 2        4k+2 4k+2 4k+2 4k+2 4k+2 4k+2 4k+2 4k+2\ngroup 3        4k+3 4k+3 4k+3 4k+3 4k+3 4k+3 4k+3 4k+3\n```\n\nThe load pattern here is constrained by the hardware — to get good [memory coalescence](https://modal.com/gpu-glossary/perf/memory-coalescing), threads should access contiguous memory. With row-wise loads, adjacent threads end up redundantly computing the same row pointer.\n\nUnfortunately, this redundancy is costly. 
Pointers are 64 bits, and int64 operations are expensive (recent data center GPUs have scaled FLOP and matmul FLOP [arithmetic bandwidth](https://modal.com/gpu-glossary/perf/arithmetic-bandwidth) far more than other op bandwidth). This cost is higher when more addresses need to be calculated, as with smaller page sizes.\n\nThe solution is to produce all 32 row pointers ahead of time, then loop over loads. This introduces a cross-thread synchronization in the form of a warp shuffle, but this is cheaper than the address calculation.\n\nThe specific pattern we use is a transpose: the eight threads in a “row” group in our warp produce row pointers for 1) different rows that 2) are not logically sequential. Instead, threads in a “column” across groups are responsible for computing (but not using) sequential row pointers.\n\n```\n                    row pointer produced by thread\n                    ------------------------------\ngroup 0        0    4    8   12   16   20   24   28\ngroup 1        1    5    9   13   17   21   25   29\ngroup 2        2    6   10   14   18   22   26   30\ngroup 3        3    7   11   15   19   23   27   31\n\nloop k (0..7)\n  group 0 loads row 4k   using pointer produced by thread k\n  group 1 loads row 4k+1 using pointer produced by thread k+8\n  group 2 loads row 4k+2 using pointer produced by thread k+16\n  group 3 loads row 4k+3 using pointer produced by thread k+24\n```\n\nThis improved memory throughput over the old method by up to 2.4x (for `page_size=1`), achieving throughput equal to or greater than what we observed at larger sizes.\n\n[PR 1940](https://github.com/Dao-AILab/flash-attention/pull/1940): add parallelism across the KV dimension (merged November 4, 2025)\n\n### Figure of Merit: Up to 4.37 TB/s memory throughput (5.27x the baseline) for small query lengths\n\n```\n| Number of KV splits | Memory throughput (TB/s) |\n| ------------------- | ------------------------ |\n| 1 (baseline)        | 0.83                     |\n| 2                   | 
2.65                     |\n| 4                   | 4.30                     |\n| 8                   | 4.27                     |\n| 16                  | 4.22                     |\n| 32                  | 4.37                     |\n| 64                  | 4.17                     |\n| 128                 | 3.83                     |\n```\n\nInference performance is generally dominated by decode time. A “typical” inference request spends most of its time producing tokens one or a few at a time based on one or a few queries against many cached KV values.\n\nBut the original FlashAttention-4 kernel architecture parallelized work in the query dimension, not the key/value dimension. For small batch size inference, which is critical for [high-interactivity, latency-sensitive applications](https://modal.com/blog/decagon-case-study), this is kryptonite. The number of distinct parallelizable instances of the kernel program ([cooperative thread arrays](https://modal.com/gpu-glossary/device-software/cooperative-thread-array)) is often much lower than the number of [streaming multiprocessors (SMs)](https://modal.com/gpu-glossary/device-hardware/streaming-multiprocessor), leaving as much as 75% of the SMs idle (faded, in the figure below) and 75% of the GPU’s peak performance on the table. [Without this change](https://github.com/Dao-AILab/flash-attention/pull/1940), FlashAttention-4 was generally slower than FlashAttention-2 on B200s!\n\nThe solution is [Flash-Decoding](https://pytorch.org/blog/flash-decoding/), aka “split KV”, introduced by Tri Dao and collaborators in the FlashAttention-2 era. We ported split KV to FA4 under the argument `num_splits`. In split KV mode, multiple CTAs work concurrently per query tile, each one computing outputs from a portion of the sequence, followed by a reduction step at the end to produce the final result. 
The extra reduction step is in a separate kernel, `flash_fwd_combine`.\n\nSplitting across the KV dimension ensures that there is work for more than one SM, and ideally for all of them.\n\nThe out-of-band reduction introduces numerical differences due to floating point non-associativity. Summing within a split, then across them, gives different results from summing across the flat sequence (another L for the monad bros). In our split path, we do the [shared memory](https://modal.com/gpu-glossary/device-software/shared-memory) output tile accumulation in 32 bit floating point to reduce the impact, but it can’t be eliminated.\n\nThe extra reduction step and its consequences mean that split KV is not always a win. So we added a simple heuristic to detect the optimal number of splits based on SM count and sequence length (triggered via `num_splits = 0`).\n\n[PR 1993](https://github.com/Dao-AILab/flash-attention/pull/1993/): reduce wasted work for small query sizes (merged January 8, 2026)\n\n### Figure of Merit: Up to 3.06x throughput for single-token decode\n\n```\n| Number of Splits | TFLOP/s | Speedup vs baseline |\n| ---------------- | ------- | ------------------- |\n| 1                | 1.79    | 1.00x               |\n| 2                | 3.39    | 1.89x               |\n| 4                | 5.47    | 3.06x               |\n| 8                | 5.23    | 2.92x               |\n| 16               | 5.08    | 2.84x               |\n| 32               | 5.12    | 2.86x               |\n```\n\nQuery parallelism is not the only choice that reflects the original FlashAttention-4 kernel’s orientation to prefill or training, where there are many query tiles. It was written to operate on two query tiles concurrently, with one dedicated [warpgroup](https://modal.com/gpu-glossary/device-software/warpgroup) of four warps to perform softmax operations for each query tile (eight warps total). 
Each tile is composed of 128 queries, so this setup assumes at least 256 queries.\n\nBut many attention passes during low-latency inference have far fewer than 256 queries in them, even with speculative decoding and grouped-query/multi-query attention (described below). The query tensors are simply padded with zeros to fill out the remainder, which results in wasted work. In particular, if there are fewer than 128 queries, all of the work on the second tile is unnecessary!\n\nSo we added another path to the core FA4 kernel that operates on only a single query tile at a time (`q_stage = 1`). This optimization is particularly useful for the short query sequence lengths seen in decode, e.g. `seqlen_q = 1`.\n\nOperating on only one query tile per block frees up the second softmax [warpgroup](https://modal.com/gpu-glossary/device-software/warpgroup), which normally runs the softmax operations on the second query tile. We repurposed it to run additional KV page loads in the non-TMA/`cpasync` case we added in [PR 1999](https://github.com/Dao-AILab/flash-attention/pull/1999), described above.\n\n[PR 2186](https://github.com/Dao-AILab/flash-attention/pull/2186/): speed up irregular Q::KV head ratios by extending GQA packing (merged March 20, 2026)\n\n### Figure of Merit: 2.92x throughput increase for single-token decode\n\n```\n| Pack GQA | TFLOP/s |\n| -------- | ------- |\n| OFF      | 7.1     |\n| ON       | 20.7    |\n```\n\nDecoding doesn’t have to mean running only a single query per sequence. [Grouped-query attention](https://arxiv.org/abs/2305.13245) (GQA) is an architectural variant that applies multiple query vectors per sequence against each KV vector. 
Like multi-query attention (MQA), the [classic Shazeer jawn](https://arxiv.org/abs/1911.02150) on which it builds, GQA increases the [arithmetic intensity](https://modal.com/gpu-glossary/perf/arithmetic-intensity) of inference.\n\nThere’s a problem: as we’ve discussed, FA4 breaks down the attention computation by *query* — and by default, each query in a GQA group is handled separately. That means the KV values need to be loaded redundantly, negating the intended reduction in memory loads.\n\nThe solution is, of course, to map the group into a single tile — aka “GQA packing”, under the flag `pack_GQA`. This was already implemented in FA4. But it only worked on certain shapes. Specifically, because this path used TMA loads, it inherited the TMA’s restrictions on alignment and layout. The number of query heads per KV head needed to divide the tile size (128). Some models we wanted to run, like GLM 4.7, didn’t satisfy this constraint.\n\nThe solution was, again, to use `cpasync` to do normal loads without the TMA, but this time for query tiles instead of KV tiles. The same basic transpose/warp shuffle strategy described for PR 2104 above was already implemented for use with Hopper GPUs, so we just needed to wire the two together.\n\n# Coda\n\nAt Modal, we are all-in on open source software for inference. 
We are contributing to kernels like FA4, to [inference engines like SGLang](https://modal.com/blog/boosting-multimodal-inference-performance-by-greater-than-10-with-a-single-python-dictionary), and to [training frameworks like SLIME](https://x.com/nanjiangwill/status/2060812693620277526?s=20) because we believe that [our infrastructure](https://modal.com/blog/truly-serverless-gpus) is the best place to deploy this software to production as part of an application, whether that’s [serving inference](/blog/decagon-case-study) or [training models](/blog/reinforcement-learning-infrastructure-problem).\n\nIf you want to contribute to projects like FlashAttention or SGLang — or if you want to build the infrastructure that runs them — we’re [hiring](https://modal.jobs).", "url": "https://wpnews.pro/news/making-flashattention-4-faster-for-inference", "canonical_source": "https://modal.com/blog/flash-attention-4-faster", "published_at": "2026-06-11 23:22:33+00:00", "updated_at": "2026-06-11 23:49:20.414581+00:00", "lang": "en", "topics": ["large-language-models", "ai-infrastructure", "ai-research", "machine-learning", "ai-chips"], "entities": ["FlashAttention-4", "Charles Frye", "David Wang", "Together AI"], "alternates": {"html": "https://wpnews.pro/news/making-flashattention-4-faster-for-inference", "markdown": "https://wpnews.pro/news/making-flashattention-4-faster-for-inference.md", "text": "https://wpnews.pro/news/making-flashattention-4-faster-for-inference.txt", "jsonld": "https://wpnews.pro/news/making-flashattention-4-faster-for-inference.jsonld"}}