Writing High-Performance Kernels in TileLang, from GEMM to MLA

TileLang offers a middle ground for writing high-performance GPU kernels. You get explicit control over shared-memory allocation, pipeline staging, and warp partitioning from Python code, and layout inference is automated. Starting from a simple GEMM, this post builds up to production-grade kernels such as DeepSeek's MLA decode by declaring tile-level operations explicitly. The alternatives are leaving those decisions to the compiler, as in Triton, or taking on the template-heavy approach of CUTLASS/CuTe.

If you write GPU kernels, you live somewhere on a spectrum. At one end is Triton: quick to write, but the compiler makes most of the layout and shared-memory decisions for you. At the other end is CUTLASS/CuTe: total control, at the cost of a lot of template machinery. TileLang sits in the middle. You write Python, but you say explicitly what lives in shared memory, how the pipeline is staged, and how warps split the work, and a layout inference pass fills in the rest.

In this post we'll cover the mental model, write a GEMM, and then build up to a real production kernel, DeepSeek's MLA decode, where the interesting decisions actually show up. The goal is not to be exhaustive. It's to show how you think about tiles, and where TileLang quietly does the hard parts for you. We'll finish with a more typical story from production: a kernel where the win wasn't speed at all.

Here's the whole idea in three points.

1. A tile (block_M × block_K, say) is owned and operated on by a thread block, a warp, or a thread. You stop thinking purely at the thread-block level the way you do in Triton, and you stop hand-managing individual threads the way you do in CUDA.
2. You decide what goes to shared memory (`T.alloc_shared`), what goes to registers (`T.alloc_fragment`), and what's thread-local (`T.alloc_local`).
   This is the biggest difference from Triton, which hides shared-memory allocation and staging inside the compiler.
3. You describe tile-level operations, and a layout inference pass works out how each buffer is distributed across threads and registers. You can still annotate a layout when you need a specific one.

If you're coming from Triton, here's the rough mapping.

| | Triton | TileLang |
|---|---|---|
| Granularity | thread block + implicit vectorization | tile (block / warp / thread) |
| Shared memory | managed by the compiler | explicit `alloc_shared` + `copy` |
| Layout | the compiler decides | inferred, but you can annotate |
| Pipelining | `tl.range` + compiler | explicit `T.Pipelined(num_stages=...)` |
| Tensor Core | `tl.dot` | `T.gemm` with a selectable warp policy |
| Backends | NVIDIA (mainly) / AMD | NVIDIA / AMD / CPU / WebGPU / CuTeDSL, plus Ascend & MUSA forks |

The short version: if you want fine control over blocking, pipeline depth, and warp partitioning without writing CUTLASS, TileLang is the sweet spot. For simple elementwise work or light fusion, Triton is still quicker to reach for.

To install, the prebuilt wheel is the easiest path:

```bash
conda create -n tilelang python=3.10 -y
conda activate tilelang
pip install tilelang  # prebuilt wheel, easiest path
```

If you're going to touch the compiler passes, build from source instead (you'll need a local LLVM/CUDA toolchain):

```bash
git clone --recursive https://github.com/tile-ai/tilelang.git
cd tilelang && pip install -r requirements-dev.txt
pip install -e . -v --no-build-isolation
```

We'll start with the kernel everyone starts with: C = ReLU(A @ B). It's small, but it touches every primitive that matters: explicit buffers, parallel copy, software pipelining, the Tensor Core call, and an L2 swizzle.

```python
import tilelang
import tilelang.language as T
import torch


@tilelang.jit
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):

    @T.prim_func
    def matmul_relu_kernel(
        A: T.Tensor((M, K), dtype),
        B: T.Tensor((K, N), dtype),
        C: T.Tensor((M, N), dtype),
    ):
        # grid dims: (blocks along N, blocks along M); 128 threads per block
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            # Say where each tile lives, explicitly.
            A_shared = T.alloc_shared((block_M, block_K), dtype)         # shared memory
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)  # register accumulator

            T.use_swizzle(panel_size=4, order="col")  # optional: better L2 reuse

            T.clear(C_local)  # zero the accumulator
            for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
                T.copy(A[by * block_M, ko * block_K], A_shared)  # global -> shared
                T.copy(B[ko * block_K, bx * block_N], B_shared)
                T.gemm(A_shared, B_shared, C_local)              # tile-level MMA

            for i, j in T.Parallel(block_M, block_N):  # fused ReLU
                C_local[i, j] = T.max(C_local[i, j], 0)

            T.copy(C_local, C[by * block_M, bx * block_N])  # write back

    return matmul_relu_kernel


M = N = K = 1024
kernel = matmul(M, N, K, block_M=128, block_N=128, block_K=64)

a = torch.randn(M, K, device="cuda", dtype=torch.float16)
b = torch.randn(K, N, device="cuda", dtype=torch.float16)
c = torch.empty(M, N, device="cuda", dtype=torch.float16)

kernel(a, b, c)
torch.testing.assert_close(c, torch.relu(a @ b), rtol=1e-2, atol=1e-2)
print("gemm ok")
```

Here is what each piece is doing.

- `A_shared` and `B_shared` live in shared memory; `C_local` lives in registers. Accumulator in registers, operands staged through shared memory: that's the standard GEMM recipe.
- `T.copy` is sugar for a parallel copy, a `T.Parallel`-style move, and the compiler derives a vectorized, coalesced global→shared transfer from it. When the copy sits inside `T.Pipelined`, it becomes `cp.async` automatically.
- `T.Pipelined(extent, num_stages=N)` is a software pipeline. `num_stages=3` means triple buffering: while you compute K-tile `ko`, the loads for `ko+1` and `ko+2` are already in flight. In Triton this is a compile option; here it's just the loop, which is easier to reason about.
- `T.gemm(A, B, C)` is the tile-level matmul. It also accepts `transpose_A` / `transpose_B` and a `policy` argument (a `T.GemmWarpPolicy` value) that controls how warps split the output tile. Hold onto that policy argument: it's the whole story when we get to MLA.
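To make the `num_stages` bullet concrete, here is a plain-Python sketch of a generic multi-stage software pipeline over K-tiles. It assumes a simplified model: a ring of `num_stages` shared-memory slots (slot = tile % num_stages), a prologue that issues `num_stages - 1` loads, and one new prefetch per iteration. This is an illustration of the schedule, not TileLang's actual lowering, and `pipeline_schedule` is a hypothetical helper.

```python
def pipeline_schedule(num_tiles: int, num_stages: int) -> list[tuple[int, list[int]]]:
    """For each compute step ko, list the K-tiles whose loads are still in flight."""
    slot_owner = {}   # shared-memory slot -> K-tile currently stored there
    consumed = set()  # K-tiles already fed to the gemm
    timeline = []

    def issue_load(tile: int) -> None:
        slot = tile % num_stages
        old = slot_owner.get(slot)
        # A slot may only be overwritten once its previous tile has been consumed.
        assert old is None or old in consumed, f"slot {slot} still holds tile {old}"
        slot_owner[slot] = tile

    # Prologue: put num_stages - 1 loads in flight before the first compute.
    for tile in range(min(num_stages - 1, num_tiles)):
        issue_load(tile)

    for ko in range(num_tiles):
        prefetch = ko + num_stages - 1
        if prefetch < num_tiles:
            # Lands in slot (ko - 1) % num_stages: unused at ko == 0, else freed by tile ko - 1.
            issue_load(prefetch)
        in_flight = sorted(t for t in slot_owner.values() if t not in consumed and t != ko)
        timeline.append((ko, in_flight))
        consumed.add(ko)  # gemm on tile ko is done; its slot can be reused
    return timeline


for ko, in_flight in pipeline_schedule(num_tiles=6, num_stages=3):
    print(f"compute K-tile {ko}; loads in flight: {in_flight}")
```

With `num_stages=3`, computing tile `ko` has `ko+1` and `ko+2` in flight until the tail of the loop drains. The assertion never fires because each prefetch goes into the slot that the previous iteration just finished with, which is why three buffers are enough for two loads ahead.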
The last piece, `T.use_swizzle(...)`, is optional: it remaps which output tile each thread block computes, so blocks that run at the same time share A and B tiles in L2.

The figure below maps all of this onto the hardware. It's worth reading against the code, because the labeled spots are exactly where TileLang hands you control that Triton keeps for itself.

Figure: GEMM in TileLang. You place every buffer in the hierarchy yourself: A_shared / B_shared sit in shared memory, C_local accumulates in registers across warps W0–W3, and the K-loop pipeline (num_stages=3) overlaps cp.async prefetches with the current gemm compute.

You can write most kernels with a small vocabulary.

- Memory: `T.alloc_shared`, `T.alloc_fragment` (registers), `T.alloc_local`.
- Data movement: `T.copy(src, dst)` between any two levels; `T.clear`, `T.fill`.
- Compute: `T.gemm(...)`; `T.Parallel(d0, d1, ...)` for elementwise loops (this is the entry point for layout inference); `T.reduce_max` / `T.reduce_sum`; scalar math like `T.exp`, `T.exp2`, `T.max`, `T.infinity`.
- Scheduling and layout: `T.Pipelined(extent, num_stages=...)`, `T.use_swizzle(...)`, and `T.annotate_layout(...)` when you need a specific layout (bank-conflict avoidance, custom swizzle).
- Dynamic shapes: `M = T.dynamic("m")` so you don't recompile per shape (it's called `T.symbolic` in some versions).

Two things you'll want often. To see what the compiler actually emitted:

```python
print(kernel.get_kernel_source())  # generated CUDA / HIP
```

And to time it:

```python
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
print(f"latency: {profiler.do_bench()} ms")
```

`T.print(buf)` prints a tile from inside the kernel, and the repo's `examples/plot_layout` draws the memory layout, which is handy when you're chasing a bank conflict or checking a swizzle.

The GEMM shows the mechanics. This next one shows why they matter. We'll walk through DeepSeek's MLA (Multi-Head Latent Attention) decode kernel, because it's the cleanest example of TileLang earning its keep. The TileLang reference lands at roughly FlashMLA's performance on H100 (benchmarked at batch 64/128 in fp16), comfortably ahead of Triton and FlashInfer, in about 80 lines of Python.
The interesting question is how, because the hard part of MLA isn't the math; it's register pressure.

Let's review the loop everyone knows. Every FlashAttention-family kernel has the same shape. Per query block, you stream over key/value blocks and keep a running max and denominator, so the full score matrix never lands in memory:

```python
# acc_s: [block_M, block_N]  scores for this KV block
# acc_o: [block_M, dim]      output accumulator
# scores_max starts at -inf; logsum and acc_o start at 0
for i in range(num_kv_blocks):
    acc_s = Q @ K[i].T
    m_prev = scores_max
    scores_max = max(m_prev, rowmax(acc_s))
    scores_scale = exp(m_prev - scores_max)
    acc_o *= scores_scale                           # rescale prior output
    acc_s = exp(acc_s - scores_max)                 # probabilities
    logsum = logsum * scores_scale + rowsum(acc_s)  # running denominator
    acc_o += acc_s @ V[i]
acc_o /= logsum                                     # final normalization
```

Both `acc_s` and `acc_o` want to stay in registers. For MHA or GQA, that's fine. For MLA, it isn't.

Here's where it gets hard. MLA's head dimensions are big: query and key are 576 wide (a 512-wide "nope" part with no positional encoding, plus a 64-wide "rope" part), and value is 512 wide. So `acc_o` is [block_M, 512], and it has to stay resident in registers across the whole KV loop.

Now bring in the hardware. On Hopper, the fast path is `wgmma.mma_async`, which ties 4 warps (128 threads) into one warpgroup and requires a minimum M of 64. So the smallest M one warpgroup can own is 64, which means one warpgroup would be holding a 64 × 512 accumulator. That's too big for a single warpgroup's register file: it spills, and performance falls off a cliff.

Figure: MLA decode in TileLang, splitting acc_o across two warpgroups. WG0 and WG1 each compute Q·K^T (policy=FullCol), exchange their score halves through S_shared, and then each compute their column slab of P·V into acc_o_L / acc_o_R. All of the bookkeeping (acc_s shape, S_shared shape, Q·K split) is derived by layout inference from the FullCol policy you annotated.

The fix is to split the output across two warpgroups. You can't shrink M below 64, so the only axis left is `dim`. Use two warpgroups: WG0 owns `acc_o[:, :256]` and WG1 owns `acc_o[:, 256:]`.
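The arithmetic behind "too big" is short enough to check by hand. This is a back-of-the-envelope sketch that assumes an fp32 accumulator (one 32-bit register per element), 128 threads per warpgroup, and the 255-registers-per-thread limit on current NVIDIA GPUs; `acc_regs_per_thread` is a hypothetical helper.

```python
THREADS_PER_WARPGROUP = 4 * 32  # 4 warps x 32 threads
MAX_REGS_PER_THREAD = 255       # per-thread register cap on NVIDIA GPUs


def acc_regs_per_thread(rows: int, cols: int, threads: int = THREADS_PER_WARPGROUP) -> int:
    """32-bit registers each thread needs just to hold a [rows, cols] fp32 tile."""
    return rows * cols // threads


one_wg = acc_regs_per_thread(64, 512)  # one warpgroup owns all of acc_o
two_wg = acc_regs_per_thread(64, 256)  # each of two warpgroups owns half the columns
print(f"1 warpgroup : {one_wg} regs/thread for acc_o alone (cap {MAX_REGS_PER_THREAD})")
print(f"2 warpgroups: {two_wg} regs/thread for acc_o")
```

Even before counting `acc_s`, the softmax state, and addressing, a single warpgroup would need 256 registers per thread for `acc_o`, one more than the cap, so spilling is unavoidable. Splitting the columns halves that to 128 and leaves headroom for everything else.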
Now each one holds a 64 × 256 accumulator, which fits. That creates a second problem, though: the P @ V step (with `policy=FullCol`, each warpgroup producing one column slab of the output) needs the complete `acc_s`, but in Q @ K each warpgroup only naturally computed half of it.

The resolution is a shared-memory swap. During Q @ K, each warpgroup writes its half of `acc_s` to shared memory and reads back the other warpgroup's half, so afterward both hold the full `acc_s` and can each compute their slab of `acc_o`. The diagram above is exactly that: split the scores, swap through `S_shared`, split the output.

In CuTe you'd hand-write the layouts, the swizzles, the Tensor Core alignment, and the producer/consumer sync to pull this off. The reason it collapses to ~80 lines here is layout inference.

Let's break down what layout inference does. You annotate intent on the `T.gemm` calls, and it propagates the constraints through the program for you:

- `policy=FullCol` on P @ V means each warpgroup needs the full `acc_s`, so `acc_s` is [block_M, block_N].
- `S_shared`, in `T.copy(S_shared, acc_s)`, is therefore also [block_M, block_N].
- Q @ K: with `FullCol`, each warpgroup's score slab is [block_M, block_N/2].

The key insight is that you never write any of those shapes. You pick the warp policy and write the math; the shapes, the swizzled layouts, and the warp-specialized producer/consumer code all come out of inference.

**The kernel skeleton.** In MLA decode the query splits into a "nope" part (`Q`, dim 512) and a "rope" part (`Q_pe`, dim 64), and the compressed latent serves as both K and V. So the score is a sum of two GEMMs, and the output is one more.
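To check that the split is exact rather than an approximation, here is a scaled-down, pure-Python simulation of one decode step. The sizes are toy stand-ins (an 8-wide latent and a 2-wide rope part instead of 512 and 64), the two "warpgroups" run one after the other, and the list `s_shared` plays the role of `S_shared`. Each warpgroup fills its half of the score columns (nope GEMM plus rope GEMM), then both read the full score tile, run the same online-softmax update, and accumulate only their own column slab of the output, with the latent `kv` reused as V. The function names are hypothetical; this illustrates the data flow, not the TileLang kernel.

```python
import math
import random


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def mla_decode_reference(q_nope, q_pe, kv, k_pe, sm_scale):
    """Naive MLA decode: full scores, full softmax, then P @ V with V = kv."""
    out = []
    for qn, qp in zip(q_nope, q_pe):
        s = [sm_scale * (dot(qn, kv[j]) + dot(qp, k_pe[j])) for j in range(len(kv))]
        m = max(s)
        p = [math.exp(x - m) for x in s]
        denom = sum(p)
        out.append([sum(p[j] * kv[j][c] for j in range(len(kv))) / denom
                    for c in range(len(kv[0]))])
    return out


def mla_decode_two_warpgroups(q_nope, q_pe, kv, k_pe, sm_scale, block_n):
    """Online softmax over KV blocks; two 'warpgroups' split Q@K and P@V by columns."""
    rows, dim = len(q_nope), len(kv[0])
    assert len(kv) % block_n == 0 and block_n % 2 == 0 and dim % 2 == 0
    half_n, half_d = block_n // 2, dim // 2
    # Per-warpgroup state: its output slab, running max, running denominator.
    acc_o = [[[0.0] * half_d for _ in range(rows)] for _ in range(2)]
    m = [[-math.inf] * rows for _ in range(2)]
    logsum = [[0.0] * rows for _ in range(2)]
    for start in range(0, len(kv), block_n):
        s_shared = [[0.0] * block_n for _ in range(rows)]
        # Q@K^T (nope GEMM + rope GEMM): warpgroup wg writes only its column half.
        for wg in range(2):
            for i in range(rows):
                for jj in range(wg * half_n, (wg + 1) * half_n):
                    j = start + jj
                    s_shared[i][jj] = sm_scale * (dot(q_nope[i], kv[j]) + dot(q_pe[i], k_pe[j]))
        # After the exchange, each warpgroup reads the full score row from s_shared
        # but accumulates only its own output columns.
        for wg in range(2):
            for i in range(rows):
                acc_s = s_shared[i]
                m_new = max(m[wg][i], max(acc_s))
                scale = math.exp(m[wg][i] - m_new)  # exp(-inf) == 0.0 on the first block
                p = [math.exp(x - m_new) for x in acc_s]
                logsum[wg][i] = logsum[wg][i] * scale + sum(p)
                m[wg][i] = m_new
                for c in range(half_d):
                    col = wg * half_d + c
                    acc_o[wg][i][c] = acc_o[wg][i][c] * scale + sum(
                        p[jj] * kv[start + jj][col] for jj in range(block_n))
    # Normalize each slab and concatenate [WG0 columns | WG1 columns].
    return [[x / logsum[0][i] for x in acc_o[0][i]] + [x / logsum[1][i] for x in acc_o[1][i]]
            for i in range(rows)]


def rand_matrix(r, c):
    return [[random.uniform(-1.0, 1.0) for _ in range(c)] for _ in range(r)]


random.seed(0)
ROWS, D_LATENT, D_ROPE, SEQ, BLOCK_N = 4, 8, 2, 8, 4
q_nope, q_pe = rand_matrix(ROWS, D_LATENT), rand_matrix(ROWS, D_ROPE)
kv, k_pe = rand_matrix(SEQ, D_LATENT), rand_matrix(SEQ, D_ROPE)
sm_scale = 1.0 / math.sqrt(D_LATENT + D_ROPE)

ref = mla_decode_reference(q_nope, q_pe, kv, k_pe, sm_scale)
out = mla_decode_two_warpgroups(q_nope, q_pe, kv, k_pe, sm_scale, BLOCK_N)
max_err = max(abs(a - b) for ra, rb in zip(ref, out) for a, b in zip(ra, rb))
print(f"max abs error vs. naive reference: {max_err:.2e}")
```

The only data that crosses between warpgroups is the score tile. The output columns never move, which is exactly why each warpgroup's accumulator can shrink to half.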
The inner loop looks like this (a representative skeleton, not line-exact; see `example_mla_decode.py`):

```python
# acc_s = Q_nope @ KV^T + Q_rope @ K_pe^T
T.gemm(Q_shared, KV_shared, acc_s, transpose_B=True,
       policy=T.GemmWarpPolicy.FullCol, clear_accum=True)
T.gemm(Q_pe_shared, K_pe_shared, acc_s, transpose_B=True,
       policy=T.GemmWarpPolicy.FullCol)

# online softmax
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
# ... exp, rescale acc_o by scores_scale, reduce_sum into logsum ...

# acc_o += P @ V   (V is the same latent KV)
T.copy(acc_s, acc_s_cast)
T.gemm(acc_s_cast, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullCol)
```

The `S_shared` exchange between the two warpgroups is the part inference inserts for you, once the `FullCol` policies force `acc_s` to be full per warpgroup.

**The nice part: the optimizations are one line each.** This is where TileLang pays off. The whole performance toolkit is one-liners, and the messy lowering is handled for you.

- L2 locality: `T.use_swizzle(panel_size, order="row")`.
- Shared-memory bank conflicts: `T.annotate_layout({S_shared: T.layout.make_swizzled_layout(S_shared)})`, an XOR-style address remapping so concurrent accesses spread across banks instead of serializing.
- Warp specialization: the producer/consumer warp split and the mbarrier synchronization between them are generated. None of it shows up in your code.
- Pipelining: `T.Pipelined(range, num_stages)` overlaps loads with compute. More stages give more overlap but use more shared memory, so it's a knob.
- Split-KV: a `num_split` parameter plus a combine kernel.

So the genuinely hard reasoning (register budget against the M ≥ 64 floor, who owns what across warpgroups, the shared-memory swap) you express by choosing a policy and writing the math. Everything that would be hundreds of fragile lines in CuTe is inference and codegen. That's the pitch, and MLA is where it's most convincing.

The last example is one of our own production kernels at AtlasCloud, from the Wan video-generation VAE on H100/H200.
It's a great illustration of the other thing TileLang is excellent at: covering a config a hand-tuned kernel can't reach, with a clean drop-in.

**The setup.** We already ship a hand-tuned fused RMSNorm + SiLU kernel. It's fast, and it's compiled for the hidden dims D ∈ {96, 192, 384} that one model config uses. A newer config needs channel widths like {160, 256, 320, 512, 640, 1024}, so on that config the hand-tuned fast path can't run. We wrote a TileLang drop-in to cover exactly that gap.

**The TileLang kernel.** A drop-in with the same interface (BTHWC in/out, same math, same eps) that supports any C that's a multiple of 32. Two passes, fully coalesced, FP32 accumulator:

```python
@T.prim_func
def main(
    X: T.Tensor((M, C), dtype),  # M = B*T*H*W rows
    gamma: T.Tensor((C,), dtype),
    Y: T.Tensor((M, C), dtype),
):
    with T.Kernel(T.ceildiv(M, BLOCK_M), threads=128) as bm:
        X_chunk = T.alloc_shared((BLOCK_M, BLOCK_C), dtype)
        ss = T.alloc_fragment((BLOCK_M,), accum_dtype)  # FP32 sum of squares
        # pass 1: loop over C in BLOCK_C chunks, accumulate sum of squares in FP32
        # rinv = rsqrt(ss / C + 1e-5)
        # pass 2: re-load X, y = silu(x * gamma * rinv), write back
```

`BLOCK_C` is 128, 64, or 32 depending on C, to respect the TMA boxDim ≤ 256 limit, and the FP32 accumulator keeps the sum of squares from overflowing in FP16.

Dispatch keeps the hand-tuned path where it works and only falls back when it has to:

```python
ATLAS_SUPPORTED_D = (96, 192, 384)


def rms_silu_dispatch(x, gamma, out):
    if x.shape[-1] in ATLAS_SUPPORTED_D:
        atlas_rms_norm_silu(x, gamma, out=out)      # keep the hand-tuned path
    else:
        tilelang_rms_silu_bthwc(x, gamma, out=out)  # cover the gap
```

**What it gained us.** All upside, and it's a true drop-in (same interface, same math, same `eps`), so it slots in behind the existing dispatch with no call-site changes.
| What | Gain |
|---|---|
| Previously unsupported config | 0 → 1: it runs now (the headline win) |
| Attention-block RMSNorm vs. the eager PyTorch norm it replaced | 42 μs → 20 μs (~2× faster) |
| End-to-end VAE at production resolution (720×1280, 21 frames) | ~1.79× encode, ~1.78× decode |

The first row is the real point: TileLang let us serve a model config that previously had no fast path at all, without touching the hand-tuned kernel that already works for the other config. One drop-in, written in Python, and a whole model path went from "throws" to "ships."

In short, the optimizations are one-liners (`T.use_swizzle`, `T.annotate_layout`, `T.Pipelined`, warp specialization, split-KV), with the lowering handled for you. The cool part of TileLang is that the hard reasoning stays in your head, not in boilerplate. You decide how to split work across warps, where buffers live, and how deep the pipeline runs. Layout inference and warp specialization then turn that into the register layouts, the swizzles, and the producer/consumer dance that would otherwise be hundreds of lines of CuTe. You pick a policy and write the math. That's the whole pitch, and it's why an 80-line MLA kernel can sit next to a hand-tuned CUTLASS one.