{"slug": "a-case-for-tracing-based-dsl-kernel-languages", "title": "A Case for Tracing Based DSL Kernel Languages", "summary": "NVIDIA's C++ template-based CUTLASS library for GPU kernels suffers from long compile times (about 20 seconds for a single Hopper GEMM kernel in the author's measurement, and a reported 17 minutes 22 seconds for two Ampere kernels), one reason for the shift toward Python-embedded DSLs; NVIDIA reports up to 100x compilation speedups for Blackwell GEMM kernels using its CuTe Python DSL. The article argues that a tracing-based embedding, as used by JAX's Pallas, is often preferable to the AST parsing used by Triton and CuTe-DSL, because ordinary Python features such as lambdas, closures, and higher-order functions keep working, which makes generic kernel libraries far easier to write, at the cost of more verbose runtime control flow.", "body_md": "On the architectural divide between parsing and tracing kernel DSLs, and what tends to go wrong in each.\n\nFor a long time the language for writing NVIDIA GPU kernels was exclusively CUDA, but since [Triton](https://github.com/triton-lang/triton) appeared, a wave of Pythonic DSLs has followed: [CuTe-DSL](https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/overview.html), [cuTile](https://docs.nvidia.com/cuda/cutile-python/), [Pallas](https://github.com/jax-ml/jax/blob/main/docs/pallas/design/design.md), [Gluon](https://github.com/triton-lang/triton/tree/main/python/triton/experimental/gluon), [Warp](https://github.com/NVIDIA/warp), and the more recent [TileLang](https://github.com/tile-ai/tilelang) used in DeepSeek’s DeepGEMM. Most of these systems share the same goal of lowering a tile-oriented program to PTX or LLVM IR, and are embedded in Python.\n\nThe question is how to embed the DSL in Python. Triton and CuTe-DSL parse the function’s source into an AST. Pallas runs the function on abstract values and traces the resulting operations. (PyTorch’s `torch.compile` intercepts CPython bytecode rather than source, but that is still parsing, just against a smaller, post-desugared grammar; the same trade-offs apply.)\n\nMost DSLs follow Triton’s lead and use *parsing*. 
This essay takes the opposite side and argues that a *tracing*-based approach is often preferable.\n\n## CUDA and Templates\n\nA CUDA kernel directly specifies the code each thread executes. A textbook fused-softmax kernel in CUDA looks roughly like this:\n\nThe element type `T` and the block size `BLOCK_SIZE` must be known at compile time, because `__shared__` memory is statically sized and the compiler must specialise loop bounds to vectorise the body. Hence every expansion of the supported configuration space multiplies the number of instantiations. Three element types and four block sizes already imply twelve instantiations, and the responsibility for dispatching among them rests with the caller.\n\nKeep adding templates and generalisations to CUDA, and one eventually arrives at a heavily templated, CUTLASS-like design.\n\n## CUTLASS: Building Blocks for CUDA Kernels\n\nCUTLASS is what C++ template metaprogramming looks like when it is used to write GPU kernels. Consider the declaration of its principal `Gemm` class, the entry point most users first encounter, from [include/cutlass/gemm/device/gemm.h](https://github.com/NVIDIA/cutlass/blob/main/include/cutlass/gemm/device/gemm.h):\n\nA fragment of the canonical Hopper warp-specialized GEMM example shows how a user composes a kernel from nested *CollectiveBuilders*, each a template that pulls in dozens of further instantiations:\n\n`examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm.cu`, lines 83–142.\n\n`Shape<_128,_128,_32>` denotes a type rather than a value, and the compiler must instantiate every dependent template once per distinct shape. 
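
The arithmetic described above (three element types and four block sizes giving twelve instantiations, plus a dispatch table the caller must maintain) can be modelled in a few lines of Python. This is a hypothetical sketch, not CUTLASS code: the dtype names, block sizes, tile shapes, the smallest-fitting-block dispatch rule and the `softmax_kernel<...>` labels are all illustrative assumptions.

```python
from itertools import product

# Illustrative configuration space (assumed values, not taken from CUTLASS).
DTYPES = ('float16', 'bfloat16', 'float32')
BLOCK_SIZES = (128, 256, 512, 1024)
TILE_SHAPES = ((128, 128, 32), (128, 256, 64), (256, 128, 64))


def count_instantiations(*axes):
    '''Each combination of compile-time parameters is a separate instantiation.'''
    total = 1
    for axis in axes:
        total *= len(axis)
    return total


def build_dispatch_table(dtypes, block_sizes):
    '''Caller-side table from (dtype, block size) to a compiled variant.'''
    # In C++ each value would be a distinct instantiation such as
    # softmax_kernel<T, BLOCK_SIZE>; here it is only a descriptive label.
    return {(d, b): f'softmax_kernel<{d}, {b}>' for d, b in product(dtypes, block_sizes)}


def dispatch(table, dtype, n_cols):
    '''Runtime choice: the smallest compiled block size that covers the row.'''
    fits = sorted(b for (d, b) in table if d == dtype and b >= n_cols)
    if not fits:
        raise ValueError(f'no instantiation for dtype={dtype}, n_cols={n_cols}')
    return table[(dtype, fits[0])]


table = build_dispatch_table(DTYPES, BLOCK_SIZES)
print(count_instantiations(DTYPES, BLOCK_SIZES))               # 12
print(count_instantiations(DTYPES, BLOCK_SIZES, TILE_SHAPES))  # 36
print(dispatch(table, 'bfloat16', 300))                        # softmax_kernel<bfloat16, 512>
```

Adding a third axis such as the tile shape multiplies the count again; in C++ every one of those combinations is a separate instantiation the compiler has to process.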
Instantiating every dependent template for every shape is what drives compile times up.\n\n**Compile Time: The Cost of Templates**\n\nWe compiled `48_hopper_warp_specialized_gemm`, a single CuTe-based GEMM file of roughly 500 lines, with `-c` and no benchmark harness, inside an `nvidia/cuda:12.5.0-devel-ubuntu22.04` container, invoking `nvcc -std=c++17 -arch=sm_90a`. The steady-state nvcc time, averaged over two warm runs on a consumer laptop, was **~20.5 s** to compile a single kernel for a single architecture.\n\nA full CUTLASS build targets several architectures, and the cost multiplies accordingly. NVIDIA’s own bug tracker records 17m22s for two Ampere `i16832gemm_s8` kernels (issue #1042) and approximately two minutes for a 30-line CuTe-DSL kernel (issue #2677). The NVIDIA developer blog post introducing the CuTe Python DSL in November 2025 frames its principal contribution as “up to two orders of magnitude reduced” compile times relative to C++ CUTLASS, quoting a “~100x compilation speedup” for Blackwell GEMM and “30-50x” for flash attention.\n\nC++ templates compile too slowly for the iteration speed kernel authors need.\n\n## Triton: DSL Embedded Into Python\n\nUsing Triton is straightforward: decorate a Python function with `@triton.jit`, mark the compile-time parameters with `tl.constexpr`, and write the kernel body in something close to NumPy. Triton also simplifies the programming model by focusing on the tile a single thread block operates on, rather than on the code for an individual thread.\n\nTriton is a pleasure to use when it works: programs look straightforward and do what one would expect. Integration into PyTorch is first-class, there is no build system to set up, and it is relatively hard to construct a malformed program that triggers a crash. 
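
As a rough mental model of what `tl.constexpr` buys, consider a JIT decorator that compiles one variant per distinct tuple of compile-time arguments and reuses it whenever only runtime arguments change. The sketch below is a pure-Python toy, not Triton’s implementation: `toy_jit`, `constexpr` and `scale_prefix` are hypothetical names, and compilation is simulated by a cache entry while the body simply runs as Python.

```python
import functools
import inspect


class constexpr:
    '''Hypothetical stand-in for tl.constexpr, used only as an annotation marker.'''


def toy_jit(fn):
    '''Toy JIT decorator: one specialisation per distinct tuple of constexpr values.

    Compilation is only simulated; the Python body is run directly.
    '''
    sig = inspect.signature(fn)
    const_names = [n for n, p in sig.parameters.items() if p.annotation is constexpr]
    cache = {}

    @functools.wraps(fn)
    def launch(*args, **kwargs):
        bound = sig.bind(*args, **kwargs)
        bound.apply_defaults()
        key = tuple(bound.arguments[n] for n in const_names)
        if key not in cache:
            # A real JIT would lower and compile fn here; each new key is a new binary.
            cache[key] = f'{fn.__name__}{key}'
        return fn(*bound.args, **bound.kwargs)

    launch.cache = cache
    return launch


@toy_jit
def scale_prefix(x, factor, BLOCK: constexpr):
    # factor is a runtime argument; BLOCK selects the specialisation.
    return [v * factor for v in x[:BLOCK]]
```

Calling `scale_prefix` with a new `factor` reuses the cached variant, while a new `BLOCK` adds a cache entry, which is the same split between runtime and compile-time arguments that `tl.constexpr` marks.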
However, writing a reusable, generic set of *libraries* in Triton is where things get tricky.\n\n## Parsing Limitations\n\nSuppose we want to expose a fused matmul whose *epilogue* (activation, scaling, or fusion applied to the accumulator after the inner product) can be supplied by the caller as a Python callable:\n\nIn a tracing-based framework this is one line: the user hands in a Python callable, and the trace records whatever operations the callable performs. In Triton this is tricky to achieve, and almost every other limitation in this section is a variation of the same underlying constraint.\n\n### What the Ecosystem Actually Does\n\nBefore we look at why, it helps to see how production kernel libraries handle this in practice. Across [Liger-Kernel](https://github.com/linkedin/Liger-Kernel), [FlagGems](https://github.com/flagos-ai/FlagGems), [Quack](https://github.com/Dao-AILab/quack), and [DeepGEMM](https://github.com/deepseek-ai/DeepGEMM), the answer is consistent: enumerate variants statically, dispatch by enum or string tag, and never accept a runtime callable. Quack’s GEMM signature is representative:\n\n### Why Lambdas Fail\n\nIf we try to pass a callable in Triton (say, `matmul(A, B, activation=lambda x: tl.where(x > 0, x, 0))`), it fails at compile time:\n\nFollowing the error’s advice and wrapping the lambda in `triton.jit` fails even earlier, during construction of the `JITFunction` object:\n\nThe constructor calls `inspect.getsourcelines(fn)` and expects the returned source to contain a `def name(` line, matched via a module-level regex at `jit.py:27`. It uses the `def` line to compute the indentation and dedent the body before handing it to the AST walker; a lambda’s source line contains no `def`, so construction fails. The supported workaround is to lift the activation into a named `@triton.jit def` and pass *that* through the constexpr argument, which does work. 
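
Both halves of this argument can be reproduced in plain Python. The sketch below assumes a simplified def-line check (a hypothetical regex, not the one at `jit.py:27`), applies it to the source lines of a named function and of a lambda, and contrasts it with a toy tracer that accepts the same kind of lambda, or a closure, as an epilogue because it only ever calls it. `Tracer`, `maximum`, `trace_matmul` and `make_scaled_relu` are illustrative names, not Triton or JAX APIs.

```python
import re

# Hypothetical, simplified def-line check; not the actual regex at jit.py:27.
DEF_LINE = re.compile('^[ ]*def [A-Za-z_][A-Za-z0-9_]*[ ]*[(]')

# Source lines as inspect.getsourcelines would return them, written out as text
# so that the example does not depend on being run from a file.
NAMED_LINES = ['@jit', 'def relu(x):', '    return max(x, 0)']
LAMBDA_LINES = ['out = matmul(a, b, activation=lambda x: max(x, 0))']


def find_def_line(lines):
    '''Index of the def line a parsing frontend would dedent around.'''
    for i, line in enumerate(lines):
        if DEF_LINE.match(line):
            return i
    raise ValueError('no `def name(` line found in source')


class Tracer:
    '''Toy tracer: arithmetic records an op in a shared graph instead of computing.'''

    def __init__(self, graph, name):
        self.graph, self.name = graph, name

    def emit(self, op, *operands):
        out = Tracer(self.graph, f't{len(self.graph)}')
        self.graph.append((out.name, op, *[getattr(v, 'name', v) for v in operands]))
        return out

    def __mul__(self, other):
        return self.emit('mul', self, other)


def maximum(a, b):
    return a.emit('max', a, b)


def trace_matmul(epilogue):
    '''Record a matmul followed by whatever the caller-supplied epilogue does.'''
    graph = []
    a, b = Tracer(graph, 'a'), Tracer(graph, 'b')
    acc = a.emit('matmul', a, b)
    epilogue(acc)  # any Python callable works: it is simply called on a tracer
    return graph


def make_scaled_relu(scale):
    # A closure: `scale` is captured from the enclosing scope.
    return lambda x: maximum(x * scale, 0.0)
```

`find_def_line` locates the `def` in the named function but raises on the lambda’s line, whereas `trace_matmul(lambda x: maximum(x * 2.0, 0.0))` records a `matmul`, a `mul` and a `max` without ever looking at the lambda’s source.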
But a named `@triton.jit` function passed through a constexpr argument is exactly what an enum entry already refers to, and the ecosystem’s convergence on enums reflects this constraint: once the supported shape is “any caller-defined `@triton.jit`-decorated function passed by name,” the library author may as well enumerate the small set they intend to support, autotune over it, and present a string-typed API.\n\n### Closures\n\nA factory-style helper that captures configuration in an enclosing scope is a common Python pattern:\n\nIn Triton this fails unconditionally:\n\nTriton does not implement closure capture at all; every free name in a kernel body is resolved against module globals and must already be a `tl.constexpr`.\n\n### Higher-Order Primitives\n\nThe same constraint shows up most sharply in the region-builder APIs of `tl.reduce` and `tl.associative_scan`: even a properly `@triton.jit`-decorated combine function cannot be passed through a kernel argument at all. The combine must be resolved lexically in the kernel’s enclosing scope, and the standard library copes by shipping one named combine per parameter combination: `_argmax_combine_tie_break_left` and `_argmax_combine_tie_break_fast` exist as two file-scoped functions because the boolean argument in which they differ cannot be threaded in at call time (`triton/python/triton/language/standard.py:158-165`). This is arguably a limitation of this particular region-builder API rather than of the parsing approach as such; one could imagine a Triton in which `tl.reduce` accepted callable arguments. Still, the analogous JAX primitive (`jax.lax.associative_scan(lambda x, y: x + y, xs)`) already accepts a lambda.\n\n### Aren’t Those Fixable?\n\nEach issue described above is, in principle, fixable. 
But the overall picture does not change: supporting metaprogramming in Triton means implementing more and more of Python inside the AST walker.\n\n## CuTe-DSL: The Same Pattern, Sharper Edges\n\nWhile writing CuTe-DSL kernels we encountered two such failures and filed bug reports against both. Each illustrates the pattern:\n\n**Two CuTe-DSL bugs, both from a parsing frontend**\n\n- **[cutlass#3268](https://github.com/NVIDIA/cutlass/issues/3268):** `storage.<field>.get_tensor(...)` works at the top level of a kernel but fails with “encountered a user-defined Python object” when placed inside a runtime `if` block. The shared-storage tensor lookup is a Python object that cannot be carried across the boundary of the lowered `scf.if` region; the surface language does not signal this in any way, so semantically identical code succeeds or fails depending on where it sits in the source. The lowering treats values in “Python” blocks and values in “IR” blocks differently, and the rules for what survives the transition are not part of the user-facing language.\n- **[cutlass#3266](https://github.com/NVIDIA/cutlass/issues/3266):** `nvvm.load.ext` rejects `BFloat16` (“Unsupported FP type for ExtLoadOp”) even though `bf16` is a first-class element type elsewhere in the DSL. Each lowering path re-enumerates the dtypes it knows about, and users hit the cliff whenever a combination of (op, dtype, layout) has not yet been wired up.\n\nLet’s see how parsing actually works, and how it imposes semantic limitations on the expressiveness of the DSL.\n\n## Under the Hood: Where the Limits Come From\n\nTriton compiles by parsing the body of the decorated function into a Python AST and walking it to emit MLIR. The walker, a subclass of `ast.NodeVisitor`, lives in [`python/triton/compiler/code_generator.py`](https://github.com/triton-lang/triton/blob/main/python/triton/compiler/code_generator.py). Consider the visitor methods that process control flow:\n\nA single Python `if` denotes two distinct constructs depending on the type of its condition. 
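
The duality is easy to reproduce in a toy walker. The sketch below is not Triton’s `code_generator.py`; it is a hypothetical `ast.NodeVisitor` that takes a dictionary of compile-time values and emits indented pseudo-IR strings (`scf.if`, `scf.for`). The same `if` syntax is folded away when its condition names a constexpr and becomes a region otherwise, while `static_range` loops are unrolled and `range` loops become a single loop op. It assumes Python 3.9 or later for `ast.unparse`.

```python
import ast
import copy

# Toy kernel source; USE_BIAS and BLOCK will be treated as constexprs.
SOURCE = '''
def kernel(x, n, USE_BIAS, BLOCK):
    if USE_BIAS:
        x = x + 1
    if n > 0:
        x = x * 2
    for i in static_range(BLOCK):
        x = x + i
    for j in range(n):
        x = x * 3
    return x
'''


class ToyCodegen(ast.NodeVisitor):
    '''Hypothetical AST walker that emits indented pseudo-IR strings.'''

    def __init__(self, constexprs):
        self.constexprs = dict(constexprs)  # names whose values are known at compile time
        self.ir, self.depth = [], 0

    def emit(self, text):
        self.ir.append('  ' * self.depth + text)

    def render(self, expr):
        '''Unparse an expression, substituting compile-time names by their values.'''
        consts = self.constexprs

        class Subst(ast.NodeTransformer):
            def visit_Name(self, node):
                if node.id in consts:
                    return ast.copy_location(ast.Constant(consts[node.id]), node)
                return node

        return ast.unparse(Subst().visit(copy.deepcopy(expr)))

    def const_value(self, expr):
        '''(True, value) for a literal or a constexpr name, else (False, None).'''
        if isinstance(expr, ast.Constant):
            return True, expr.value
        if isinstance(expr, ast.Name) and expr.id in self.constexprs:
            return True, self.constexprs[expr.id]
        return False, None

    def region(self, header, body):
        self.emit(header + ' {')
        self.depth += 1
        for stmt in body:
            self.visit(stmt)
        self.depth -= 1
        self.emit('}')

    def visit_FunctionDef(self, node):
        for stmt in node.body:
            self.visit(stmt)

    def visit_Assign(self, node):
        self.emit(f'{node.targets[0].id} = {self.render(node.value)}')

    def visit_Return(self, node):
        self.emit(f'return {self.render(node.value)}')

    def visit_If(self, node):
        known, value = self.const_value(node.test)
        if known:  # constexpr condition: pick one branch now, never visit the other
            for stmt in (node.body if value else node.orelse):
                self.visit(stmt)
            return
        # runtime condition: emit the body, and any else branch, as regions
        self.region(f'scf.if {self.render(node.test)}', node.body)
        if node.orelse:
            self.region('else', node.orelse)

    def visit_For(self, node):
        var, call = node.target.id, node.iter
        if call.func.id == 'static_range':  # unrolled at parse time
            known, count = self.const_value(call.args[0])
            if not known:
                raise TypeError('static_range needs a compile-time bound')
            for i in range(count):
                self.constexprs[var] = i
                for stmt in node.body:
                    self.visit(stmt)
            self.constexprs.pop(var, None)
        else:  # any other iterator becomes a single runtime loop op
            self.region(f'scf.for {var} in {self.render(call)}', node.body)


def compile_kernel(source, constexprs):
    gen = ToyCodegen(constexprs)
    gen.visit(ast.parse(source))
    return gen.ir
```

With `USE_BIAS=True` and `BLOCK=2`, the first `if` disappears into its body, the second becomes an `scf.if` region, the `static_range` loop turns into two straight-line statements, and the `range(n)` loop stays a single `scf.for`. Nothing in the kernel text tells the reader which of these will happen; the answer lives in the dictionary of constexpr values.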
When `cond` is a Triton tensor, the walker emits an `scf.if` region and generates both branches. When `cond` is a constexpr, however, the walker evaluates the condition at compile time, selects one branch, and never visits the other. The same duality governs `for` loops: `tl.static_range(8)` unrolls the body into eight copies at parse time, whereas `tl.range(8)` produces a single MLIR `scf.for`. For `if`, the two meanings are syntactically indistinguishable yet denote entirely different constructs: the reader has to know which one is in force at each site, and there is no syntax to force one or the other.\n\n## Pallas: Trace Instead of Parse\n\nJAX’s Pallas kernel language follows JAX’s syntax and uses *tracing*. A Pallas kernel is an ordinary Python function that operates on JAX `Ref`s and is passed to `pl.pallas_call`, which executes it under a tracer. The body is not parsed; it is *run* with *tracer* objects as arguments to generate a *jaxpr*, an IR capturing the semantics of the kernel.\n\nThere is no separate DSL to parse: the operations are `jnp.max`, `jnp.exp`, and `jnp.sum`, drawn from the NumPy-shaped API that JAX users already know and dispatched through tracers that record into the kernel’s IR rather than executing on the CPU. The Pallas [design document](https://github.com/jax-ml/jax/blob/main/docs/pallas/design/design.md) states the contrast: “JAX users are already accustomed to the benefits (and limitations) of programming with JAX and its tracing-based transformations. 
This means users can use closures and other familiar Python constructs when writing Pallas kernels.”\n\nThe clearest illustration of what tracing buys comes from [Tokamax](https://github.com/openxla/tokamax)’s SM90 flash-attention kernel, where compile-time and runtime control flow sit side by side in a single body:\n\nThe two regimes do distinct jobs. The outer `if is_causal:` is a plain Python branch evaluated at trace time; the untaken branch is never compiled and never seen by the backend. It selects between two entirely different *runtime* control-flow shapes: the non-causal branch is a single `lax.fori_loop` over the full KV range, while the causal branch splits the range into a non-causal prefix and a causal tail, each driven by its own `lax.fori_loop`. (The `lax.cond` wrapping the tail is, per the source comment, a workaround for a compiler bug rather than a deliberate part of the algorithm.) The causal body itself is specialised at trace time via `functools.partial(loop_body, do_causal=True)`. The backend never has to guess which is which: a Python `if` is always compile-time, and `lax.fori_loop` is always runtime. The trace *is* the program.\n\n### Runtime Combinators Look Unusual\n\nThe most common complaint about tracing-based DSLs is that runtime control flow looks awkward. Where Triton lets the user write `for k in tl.range(...)` or `if dynamic_cond:` in plain Python, Pallas requires `lax.fori_loop(lb, ub, body, carry)` with an explicit carry tuple, or a `@pl.when(cond)` decorator over a freshly defined nested function. At first reading, the combinators feel like a step sideways from idiomatic Python.\n\nAuthors get over this quickly, because the patterns are few and soon become second nature. 
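
The meaning of the two combinators can be pinned down with plain-Python reference versions. The sketch below assumes the documented contract of `jax.lax.fori_loop` (apply `body(i, carry)` for each `i` in `range(lower, upper)` and return the final carry) and models `pl.when` as a decorator that runs the decorated block only when the condition holds. These are stand-ins rather than JAX code, and `loop_body`, `attend`, `where` and `store_if` are hypothetical names; the split into an unmasked prefix and a masked tail loosely mirrors the causal structure described above.

```python
import functools


def fori_loop(lower, upper, body, init):
    '''Reference semantics of lax.fori_loop: thread a carry through body(i, carry).'''
    carry = init
    for i in range(lower, upper):
        carry = body(i, carry)
    return carry


def when(cond):
    '''Reference semantics of a decorator-form runtime branch such as pl.when.'''
    def decorator(block):
        if cond:
            block()
    return decorator


def where(pred, a, b):
    '''Stand-in for an elementwise select such as jnp.where.'''
    return a if pred else b


def loop_body(i, carry, *, q_pos, do_causal):
    '''Toy inner loop: accumulate key index i as a score, masking keys after q_pos.'''
    total, count = carry
    if do_causal:  # a static flag, fixed by functools.partial before the loop is built
        visible = i <= q_pos
        return total + where(visible, i, 0), count + where(visible, 1, 0)
    return total + i, count + 1


def attend(q_pos, n_keys, is_causal):
    plain = functools.partial(loop_body, q_pos=q_pos, do_causal=False)
    if not is_causal:  # trace-time Python branch: picks the loop structure
        return fori_loop(0, n_keys, plain, (0, 0))
    masked = functools.partial(loop_body, q_pos=q_pos, do_causal=True)
    split = min(q_pos, n_keys)  # keys before q_pos never need the mask
    carry = fori_loop(0, split, plain, (0, 0))
    return fori_loop(split, n_keys, masked, carry)


def store_if(out, value, valid):
    @when(valid)
    def _():
        out.append(value)
    return out
```

The Python `if not is_causal:` decides which loops exist at all, `functools.partial` fixes `do_causal` before any loop sees the body, and only `fori_loop` and `when` describe work that would happen at run time.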
Idiomatic Pallas rarely uses the combinators raw; it wraps each one in a layer of compile-time Python that decides *which* runtime shape to emit, so a single source file can serve causal and non-causal attention by building two specialised `loop_body` closures with `functools.partial` and feeding them to two different `lax.fori_loop` arrangements (exactly what the Tokamax kernel above does). Two recurring patterns are worth naming:\n\n- `functools.partial(body, ..., static_flag=True)` to specialise an inner-loop body before passing it to `lax.fori_loop`.\n- A decorator-form `@pl.when(cond)` over a fresh `_()` function, so that a runtime branch reads like a labelled block.\n\nThe combinators stay small and visible because the Python around them does the structural work first.\n\n### Compile-Time Python as a Metaprogramming Layer\n\nBecause compile-time control flow is just Python, anything Python can compute is available for metaprogramming the kernel, including patterns that a parsing-based DSL would have to special-case in its AST walker. The Pallas `ragged_dot` Mosaic-GPU kernel is a good example. The kernel needs to store a dynamic number of rows (somewhere between 1 and `block_m`) to global memory, but TMA descriptors require statically known tile sizes. Idiomatic Pallas resolves the contradiction by unrolling a Python `while` loop into a logarithmic ladder of fixed-size stores, each guarded by a runtime bit-test on the dynamic length:\n\nThe Python `while` runs at trace time and emits `log2(block_m)` separate TMA stores with compile-time constant tile sizes; `@pl.when` wraps each one in a runtime bit-test on the dynamic length, so only the stores whose tile sizes add up to the desired total actually fire. The Python loop variable does double duty as both an unroll counter and a bit position. 
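
The bit-decomposition can be checked in plain Python. This is a hypothetical simulation, not the `ragged_dot` source: `build_store_ladder` plays the trace-time `while` loop and returns the static store sizes, and `run_stores` plays the runtime guards for one dynamic row count. It assumes `block_m` is a power of two and that the row count is at most `block_m`.

```python
def build_store_ladder(block_m):
    '''Trace time: unroll a Python while loop into fixed-size store descriptors.'''
    if block_m <= 0 or (block_m & (block_m - 1)) != 0:
        raise ValueError('block_m must be a positive power of two')
    stores, size = [], block_m
    while size > 0:
        # One statically sized store per iteration; `size` is both the
        # unroll counter and the bit that the runtime guard will test.
        stores.append(size)
        size //= 2
    return stores


def run_stores(stores, n_rows):
    '''Run time: a store fires only if its bit is set in the dynamic row count.'''
    if not 0 <= n_rows <= stores[0]:
        raise ValueError('n_rows out of range for this ladder')
    offset, issued = 0, []
    for size in stores:
        if n_rows & size:  # stands in for a runtime guard such as @pl.when
            issued.append((offset, size))  # (first row, rows written)
            offset += size
    return issued
```

For `block_m = 16` and 13 valid rows, the stores of 8, 4 and 1 rows fire at offsets 0, 8 and 12. Every store size is a compile-time constant, which is what a TMA descriptor needs; only the guards depend on the data. This sketch emits one store per power of two from `block_m` down to 1.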
Implementing this in a parsing-based DSL would require either bespoke AST support for Python integer loops that close over kernel refs, or four hand-written `@pl.when` blocks in the source.\n\nOrdinary Python debugging tooling continues to work inside Pallas kernels. A plain `print` call inside a traced function prints once, at trace time, showing the tracer’s abstract value, which is useful for inspecting types and shapes during compilation. For runtime values, one goes through the deliberately side-effecting primitives:\n\nThe tracing approach gets four debugging affordances either free or cheap. Parsing-based DSLs can offer the same, but at considerable engineering cost: Triton, for instance, provides both source-mapped error messages and MLIR `loc` attributes, but only because its AST walker has been carefully threaded with source locations at every visitor method. The four affordances are:\n\n- **Native breakpoints.** `jax.debug.breakpoint()` drops into a real Python debugger at the corresponding point in the compiled program (`jax/_src/debugger/core.py:160`); Pallas’s *interpret mode* runs the kernel as a plain Python loop on the CPU, where a bare Python `breakpoint()` and `pdb` work without any special integration.\n- **Native print.** In interpret mode, plain `print()` just works, because the kernel is a normal Python function. `jax.debug.print` and `pl.debug_print` extend that to compiled execution.\n- **Full Python tracebacks.** When the trace raises, the traceback is the actual Python call stack into the kernel source (the line of the user’s function that produced the offending operation). 
Triton points at the kernel source too, via the `def_file_col_number` machinery we touched on earlier; tracing simply gets the same for free by being Python.\n- **Op-level provenance in MLIR.** JAX threads a `SourceInfo` object (`jax/_src/source_info_util.py:136`) carrying the originating Python frame through every primitive binding, and the MLIR lowering converts it to a `loc(...)` attribute via `source_info_to_location` (`jax/_src/interpreters/mlir.py:524`). Triton emits `loc` attributes as well, but only because each visitor method explicitly attaches the source location it is responsible for; tracing gets the annotation as a side effect of every primitive dispatch.\n\n## What Tracing Gives Up\n\nThe largest cost of tracing is the loss of succinct control flow. Every Python `if` is evaluated by the Python interpreter at trace time, so `if some_tracer > 0:` is illegal: the condition is an abstract value rather than a concrete boolean, and an error is raised. Runtime control flow therefore requires the explicit combinators `lax.cond`, `lax.fori_loop`, or `pl.when`, and the resulting code is more verbose than its Triton counterpart. A function written with plain Python control flow does not automatically work as a kernel; it must be refactored.\n\nThe Tokamax SM90 attention example earlier in the post shows what this looks like in practice: a pair of `lax.fori_loop` calls in the causal branch versus a single one in the non-causal branch, with a closure-specialised inner body to keep the combinators small. In Triton the same algorithm would be written with two ordinary Python `for` loops and an ordinary `if`: visually cleaner at the call site, at the cost of the reader having to keep the compile-time / runtime distinction straight in their head. 
Pallas pays that visual overhead to keep the two regimes syntactically distinct.\n\nOther tracing taxes are less visible but real: shape polymorphism requires explicit machinery (`jax.export`, polymorphic shape specs) rather than falling out of the source; effectful operations need lifting into JAX primitives; and a kernel parameter that ought to be compile-time, but happens to arrive as a concrete Python value during one trace and as a tracer during the next, will silently produce two different jaxprs. None of these is insurmountable, but they are real friction.\n\nThe two approaches treat Python itself very differently. Triton overloads `if` to mean either compile-time or runtime branching depending on the type of the condition, and pays for the overload indefinitely: every reader must re-derive which semantics is in force at each site, and every new Python feature needs a corresponding `visit_X` method. Pallas keeps the two regimes syntactically distinct and is never obliged to implement a visitor for `ast.ListComp`: a list comprehension simply executes under the Python interpreter, producing either a list of tracers or a list of concrete values, and either is acceptable.\n\nParsing buys cleaner surface syntax at the cost of making generic libraries painful to write, while tracing accepts syntactic overhead in simple kernels in exchange for the expressiveness those libraries need.\n\n## Two Toy Implementations\n\nThe architectural difference between the two approaches is compact enough to fit on a single screen. The two implementations below are sketches rather than runnable compilers (they omit a proper symbol table and a real type system), but they make the shape of each strategy concrete. The first is an AST-based mini-DSL.\n\nThe tracing analogue is comparable in size.\n\nThe two implementations differ by only a handful of lines, yet each difference between them reappears, magnified, as a difference between Triton and Pallas. 
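
A compact pair in this spirit might look as follows. This is a hypothetical sketch: `ToyParser` walks the AST and knows only `Return`, `BinOp`, `Name` and `Constant`, while `Tracer` records operations as they are dispatched, refuses `bool()` on itself, and relies on an explicit `cond` that traces both branches and records a `select` (a real lowering would emit a region instead). All names are illustrative.

```python
import ast


# Parsing: the frontend needs a visitor for every construct it accepts.
class ToyParser(ast.NodeVisitor):
    '''Hypothetical AST frontend emitting SSA-style tuples; knows only a few nodes.'''

    def __init__(self):
        self.ir = []

    def generic_visit(self, node):
        raise NotImplementedError(f'no visitor for ast.{type(node).__name__}')

    def visit_Module(self, node):
        for stmt in node.body:
            self.visit(stmt)

    visit_FunctionDef = visit_Module  # only the body matters in this toy

    def visit_Return(self, node):
        self.ir.append(('return', self.visit(node.value)))

    def visit_BinOp(self, node):
        op = {ast.Add: 'add', ast.Mult: 'mul'}[type(node.op)]
        lhs, rhs = self.visit(node.left), self.visit(node.right)
        name = f'%{len(self.ir)}'
        self.ir.append((name, op, lhs, rhs))
        return name

    def visit_Name(self, node):
        return node.id

    def visit_Constant(self, node):
        return node.value


def parse_compile(source):
    parser = ToyParser()
    parser.visit(ast.parse(source))
    return parser.ir


# Tracing: run the Python function on objects that record what is done to them.
class Tracer:
    def __init__(self, ir, name):
        self.ir, self.name = ir, name

    def emit(self, op, *args):
        name = f'%{len(self.ir)}'
        self.ir.append((name, op, *[getattr(a, 'name', a) for a in args]))
        return Tracer(self.ir, name)

    def __add__(self, other):
        return self.emit('add', self, other)

    def __mul__(self, other):
        return self.emit('mul', self, other)

    def __bool__(self):
        raise TypeError('cannot branch on a traced value; use cond()')


def cond(pred, true_fn, false_fn, x):
    '''Explicit runtime branch: trace both sides, then record a select.'''
    t, f = true_fn(x), false_fn(x)
    return pred.emit('select', pred, t, f)


def trace_compile(fn, *arg_names):
    ir = []
    out = fn(*[Tracer(ir, n) for n in arg_names])
    outs = out if isinstance(out, list) else [out]
    ir.append(('return', *[getattr(o, 'name', o) for o in outs]))
    return ir


SRC = '''
def f(x, y):
    return x * 2 + y
'''

LISTCOMP_SRC = '''
def g(x):
    return [x * k for k in (1, 2)]
'''


def f(x, y):
    return x * 2 + y


def branchy(x):
    if x:  # fine for a concrete value, an error for a Tracer
        return x * 2
    return x + 1
```

Both front ends produce the same IR for `x * 2 + y`. They part ways on anything else: the parser rejects a list comprehension until someone writes `visit_ListComp`, while the tracer simply lets Python evaluate it; conversely, `branchy(3)` runs as ordinary Python, but tracing it raises until the branch is rewritten with `cond`.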
The AST-based version requires a visitor for every Python construct the user might invoke, whereas the tracing version inherits every Python construct for free, provided it terminates in operator dispatch. In the AST version, `if x:` demands special-casing inside `visit_If` for the constexpr and runtime cases; in the tracing version the same line either succeeds, because `x` is a concrete Python boolean, or raises, because `x` is a tracer and the author must instead reach for the explicit `cond(pred, ...)` combinator. Both behaviours are useful; only one needs a 1700-line visitor to implement.\n\n## Conclusion\n\nTracing pays the cost of awkward runtime syntax. In exchange it gets a much simpler and more robust compiler implementation, and the ability to express complex generic algorithms through Python metaprogramming, which is often what a reusable set of libraries requires.\n\nSource for the snippets above: Triton commit at [triton-lang/triton](https://github.com/triton-lang/triton), CUTLASS at [NVIDIA/cutlass](https://github.com/NVIDIA/cutlass), JAX at [jax-ml/jax](https://github.com/jax-ml/jax). The two CuTe bugs are [#3266](https://github.com/NVIDIA/cutlass/issues/3266) and [#3268](https://github.com/NVIDIA/cutlass/issues/3268). The compile timing in the receipt box was measured locally in `nvidia/cuda:12.5.0-devel-ubuntu22.04`; the figure will vary with toolkit version and host configuration. 
Corrections and counter-examples are welcome by email.", "url": "https://wpnews.pro/news/a-case-for-tracing-based-dsl-kernel-languages", "canonical_source": "https://metaworld.me/blog/public/A-Case-for-Tracing-Based-DSL-Kernel-Languages", "published_at": "2026-05-27 03:10:29+00:00", "updated_at": "2026-05-27 03:27:12.402442+00:00", "lang": "en", "topics": ["ai-infrastructure", "ai-chips", "ai-research", "ai-tools", "machine-learning"], "entities": ["NVIDIA", "Triton", "CuTe-DSL", "cuTile", "Pallas", "Gluon", "Warp", "TileLang"], "alternates": {"html": "https://wpnews.pro/news/a-case-for-tracing-based-dsl-kernel-languages", "markdown": "https://wpnews.pro/news/a-case-for-tracing-based-dsl-kernel-languages.md", "text": "https://wpnews.pro/news/a-case-for-tracing-based-dsl-kernel-languages.txt", "jsonld": "https://wpnews.pro/news/a-case-for-tracing-based-dsl-kernel-languages.jsonld"}}