On the architectural divide between parsing and tracing kernel DSLs, and what tends to go wrong in each.
For years the language for writing NVIDIA GPU kernels was, in practice, CUDA. Since Triton appeared, a wave of Pythonic DSLs has followed: CuTe-DSL, cuTile, Pallas, Gluon, Warp, and the more recent TileLang used in DeepSeek’s DeepGEMM. Most of these systems share the same goal, lowering a tile-oriented program into PTX or LLVM-IR, and are embedded in Python.
The question is how to embed the DSL into Python. Triton and CuTe-DSL parse the source AST. Pallas runs the function under abstract values and traces the resulting operations. (PyTorch’s `torch.compile` intercepts CPython bytecode rather than source, but that is still parsing, just against a smaller, post-desugared grammar; the same trade-offs apply.)
Most DSLs follow Triton’s lead and parse. This essay takes the other side and argues that a tracing-based approach is often preferable.
CUDA and Templates #
A CUDA kernel specifies directly the code that each thread executes. A textbook fused-softmax kernel in CUDA looks roughly like this:
The element type `T` and the block size `BLOCK_SIZE` must be known at compile time, because `__shared__` memory is statically sized and the compiler must specialise loop bounds to vectorise the body. Hence any expansion of the supported configuration space multiplies the number of instantiations. Three element types and four block sizes already imply twelve instantiations, and the responsibility for dispatching among them rests with the caller.
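The combinatorics are easy to see from the host side. The Python sketch below enumerates the configuration space and picks an instantiation the way a caller might; the kernel name `softmax_kernel`, the concrete block sizes, and the smallest-block-that-covers-the-row policy are all assumptions made for illustration.

```python
from itertools import product

# Hypothetical configuration space: three element types x four block sizes.
DTYPES = ("__half", "__nv_bfloat16", "float")
BLOCK_SIZES = (128, 256, 512, 1024)

# One entry per template instantiation nvcc would have to compile
# (the kernel name is illustrative, not taken from any real code base).
INSTANTIATIONS = {
    (dtype, block): f"softmax_kernel<{dtype}, {block}>"
    for dtype, block in product(DTYPES, BLOCK_SIZES)
}


def dispatch(dtype: str, n_cols: int) -> str:
    """Caller-side dispatch: choose the smallest block size that covers one row."""
    if dtype not in DTYPES:
        raise ValueError(f"no instantiations for element type {dtype!r}")
    for block in BLOCK_SIZES:
        if n_cols <= block:
            return INSTANTIATIONS[(dtype, block)]
    raise ValueError(f"no instantiation covers rows of length {n_cols}")
```

Adding one more element type adds four entries to the table; adding one more block size adds three.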
Keep adding templates and generalisations to such code and one eventually arrives at a heavily templated, CUTLASS-like design.
CUTLASS: Building Blocks for CUDA Kernels #
CUTLASS is what C++ template metaprogramming looks like when taken as a way to write GPU kernels. Consider the declaration of its principal `Gemm` class, the entry point most users first encounter, from `include/cutlass/gemm/device/gemm.h`:
A fragment of the canonical Hopper warp-specialized GEMM example shows how a user composes a kernel from nested `CollectiveBuilder`s, each a template that pulls in dozens of further instantiations (from `examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm.cu`, lines 83–142):
`Shape<_128,_128,_32>` denotes a type rather than a value, so the compiler must instantiate every dependent template once per distinct shape. The result is long compile times.
Compile Time: The Cost of Templates
We compiled `48_hopper_warp_specialized_gemm`, a single CuTe-based GEMM file of roughly 500 lines, with `-c` and no benchmark harness, inside an `nvidia/cuda:12.5.0-devel-ubuntu22.04` container, invoking `nvcc -std=c++17 -arch=sm_90a`. The steady-state nvcc time, averaged over two warm runs on a consumer laptop, was:

~20.5 s to compile a single kernel for a single architecture
A full CUTLASS build targets several architectures, and the cost multiplies accordingly. NVIDIA’s own bug tracker records 17m22s for two Ampere `i16832gemm_s8` kernels (issue #1042) and approximately two minutes for a 30-line CuTe-DSL kernel (issue #2677). The NVIDIA developer blog post introducing the CuTe Python DSL in November 2025 frames its principal contribution as “up to two orders of magnitude reduced” compile times relative to C++ CUTLASS, with a quoted “~100x compilation speedup” for Blackwell GEMM and “30-50x” for flash attention.
C++ templates compile too slowly for the iteration speed kernel authors need.
Triton: DSL Embedded Into Python #
Using Triton is straightforward: decorate a Python function with `@triton.jit`, mark the compile-time parameters with `tl.constexpr`, and write the kernel body in something close to NumPy. Triton also simplifies the programming model: it focuses on the tile a single thread block operates on, rather than on the code for an individual thread.
Triton is a pleasure to use when it works: the program looks straightforward and does what one would expect. Integration into PyTorch is first-class, there is no build system to set up, and it is relatively hard to construct a malformed program that triggers a crash. However, when one wants to write a reusable generic set of libraries in Triton, things get tricky.
Parsing Limitations #
Suppose we want to expose a fused matmul whose epilogue (activation, scaling, or fusion applied to the accumulator after the inner product) can be supplied by the caller as a Python callable:
In a tracing-based framework this is one line: the user hands in a Python callable, and the trace records whatever operations the callable performs. In Triton this is tricky to achieve, and almost every other limitation in this section is a variation of the same underlying constraint.
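To see why tracing makes this trivial, consider a deliberately tiny tracer. In the sketch below (all names hypothetical, with tuples standing in for IR ops), operator overloads append to a shared trace, so `matmul` never needs to know what the epilogue is: it just calls it on the accumulator, and whatever the callable does gets recorded.

```python
class Tracer:
    """Hypothetical minimal tracer: each op appends one entry to a shared trace."""

    def __init__(self, trace, name):
        self.trace, self.name = trace, name

    def emit(self, op, other):
        out = Tracer(self.trace, f"%{len(self.trace)}")
        rhs = other.name if isinstance(other, Tracer) else repr(other)
        self.trace.append((out.name, op, self.name, rhs))
        return out

    def __add__(self, other):
        return self.emit("add", other)

    def __mul__(self, other):
        return self.emit("mul", other)


def maximum(x, y):
    """A DSL-level primitive; like the operators, it only records."""
    return x.emit("max", y)


def matmul(a, b, epilogue=lambda acc: acc):
    acc = a.emit("dot", b)  # record the inner product
    return epilogue(acc)    # the caller's callable simply runs on the tracer


# Usage: a one-line epilogue passed as an ordinary lambda.
trace = []
out = matmul(Tracer(trace, "a"), Tracer(trace, "b"),
             epilogue=lambda acc: maximum(acc * 0.5, 0.0))
```

After the call, `trace` holds a `dot`, a `mul` by `0.5`, and a `max` with `0.0`, in that order; no part of `matmul` had to special-case the lambda.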
What the Ecosystem Actually Does
Before we look at why, it helps to see how production kernel libraries handle this in practice. Across Liger-Kernel, FlagGems, Quack, and DeepGEMM, the answer is consistent: enumerate variants statically, dispatch by enum or string tag, and never accept a runtime callable. Quack’s GEMM signature is representative:
Why Lambdas Fail
If we try to pass a callable in Triton (say, `matmul(A, B, activation=lambda x: tl.where(x > 0, x, 0))`), it fails at compile time:
Following the error’s advice and wrapping the lambda in `triton.jit` fails earlier, during construction of the `JITFunction` object:
The constructor calls `inspect.getsourcelines(fn)` and expects the returned source to contain a `def name(` line, matched via a module-level regex at `jit.py:27`. It uses the `def` to compute indentation and dedent the body before handing it to the AST walker; a lambda’s source line contains no `def`, and the construction step fails. The supported workaround is to lift the activation into a named `@triton.jit` `def` and pass that through the constexpr argument, which does work. But that is exactly what an enum entry already refers to, and the ecosystem’s convergence on enums reflects this constraint: once the supported shape is “any caller-defined `@triton.jit`-decorated function passed by name,” the library author may as well enumerate the small set they intend to support, autotune over it, and present a string-typed API.
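The failure is easy to reproduce in miniature. The sketch below mimics only the dedent step; `DEF_LINE` and `strip_to_def` are hypothetical stand-ins (the real regex in `jit.py` may differ), and the function takes a source string directly instead of calling `inspect.getsourcelines`.

```python
import re
import textwrap

# Hypothetical stand-in for the module-level regex: a line of the form `def name(`.
DEF_LINE = re.compile(r"^[ \t]*def\s+\w+\s*\(", re.MULTILINE)


def strip_to_def(src: str) -> str:
    """Drop everything before the `def` line (e.g. decorators), then dedent."""
    match = DEF_LINE.search(src)
    if match is None:
        raise ValueError("source has no `def name(` line to anchor on")
    return textwrap.dedent(src[match.start():])
```

A decorated, indented `def` dedents cleanly; a lambda assigned on one line (`act = lambda x: x * 2`) has no such line to anchor on and is rejected before any AST walking happens.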
Closures
A factory-style helper that captures configuration in an enclosing scope is a common pattern in Python:
In Triton this fails unconditionally:
Triton does not implement closure capture at all; every free name in a kernel body is resolved against module globals and must already be a `tl.constexpr`.
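CPython’s own function objects show where closure-captured configuration lives: in cells listed by `__code__.co_freevars`, not in `__globals__`. The two resolvers below are hypothetical sketches. The first mimics a frontend that consults only module globals and therefore cannot see `scale`; the second also reads closure cells. A tracer needs neither, because it simply calls the function.

```python
def make_scaled_epilogue(scale):
    # Factory: `scale` is captured in a closure cell, not stored in module globals.
    def epilogue(acc):
        return acc * scale
    return epilogue


def resolve_globals_only(fn, name):
    """Hypothetical frontend that resolves free names against module globals only."""
    if name in fn.__globals__:
        return fn.__globals__[name]
    raise NameError(f"{name!r} is not defined at module scope")


def resolve_with_closure(fn, name):
    """Hypothetical frontend that also reads closure cells (paired with co_freevars)."""
    for var, cell in zip(fn.__code__.co_freevars, fn.__closure__ or ()):
        if var == name:
            return cell.cell_contents
    return resolve_globals_only(fn, name)
```

For `ep = make_scaled_epilogue(0.5)`, `resolve_globals_only(ep, "scale")` raises `NameError`, while `resolve_with_closure` returns `0.5`; calling `ep(4.0)` directly, as a tracer would, just returns `2.0`.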
Higher-Order Primitives
The same constraint shows up most sharply in the region-builder APIs of `tl.reduce` and `tl.associative_scan`: even a properly `@triton.jit`-decorated combine function cannot be passed through a kernel argument at all. The combine must be resolved lexically in the kernel’s enclosing scope, and the standard library copes by shipping one named combine per parameter combination: `_argmax_combine_tie_break_left` and `_argmax_combine_tie_break_fast` exist as two file-scoped functions because the boolean argument they differ in cannot be threaded in at call time (`triton/python/triton/language/standard.py:158-165`). This is arguably more a limitation of the specific region-builder API than of the parsing approach as such; one could imagine a Triton in which `tl.reduce` accepted callable arguments. Still, the analogous JAX primitive (`jax.lax.associative_scan(lambda x, y: x + y, xs)`) accepts a lambda directly.
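In plain Python, the boolean that gets baked into two separately named functions becomes a keyword bound with `functools.partial`. The sketch below is ours, not Triton’s: `argmax_combine` operates on `(value, index)` pairs, and `associative_scan` is a hypothetical sequential reference that produces inclusive prefix results (a real implementation would exploit associativity to run in parallel).

```python
from functools import partial


def argmax_combine(a, b, *, tie_break_left):
    """Combine two (value, index) pairs; the flag picks the tie-breaking rule."""
    (va, ia), (vb, ib) = a, b
    if va != vb:
        return a if va > vb else b
    if tie_break_left:
        return a if ia < ib else b  # keep the lower index on ties
    return b  # keep the later operand on ties (still associative)


def associative_scan(combine, xs):
    """Sequential reference: inclusive prefix combination of xs."""
    out, acc = [], None
    for x in xs:
        acc = x if acc is None else combine(acc, x)
        out.append(acc)
    return out


values = [1, 3, 3, 2]
pairs = list(zip(values, range(len(values))))
left = associative_scan(partial(argmax_combine, tie_break_left=True), pairs)
right = associative_scan(partial(argmax_combine, tie_break_left=False), pairs)
```

One combine function, two specialisations, chosen at the call site.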
Aren’t Those Fixable?
Each issue described above is, in principle, fixable. But every fix takes the same form: supporting metaprogramming in Triton means re-implementing more and more of Python inside the AST walker.
CuTe-DSL: The Same Pattern, Sharper Edges #
While writing CuTe-DSL kernels we encountered two such failures and filed bug reports against both. Both come from the parsing frontend, and each illustrates the pattern:

- cutlass#3268: `storage.<field>.get_tensor(...)` works at the top level of a kernel but fails with “encountered a user-defined Python object” when placed inside a runtime `if` block. The shared-storage tensor lookup is a Python object that cannot be carried across the boundary of the lowered `scf.if` region; the surface language does not signal this in any way, so semantically identical code succeeds or fails depending on where it sits in the source. The lowering treats values in “Python” blocks and values in “IR” blocks differently, and the rules for what survives the transition are not part of the user-facing language.
- cutlass#3266: `nvvm.load.ext` rejects `BFloat16` (“Unsupported FP type for ExtLoadOp”) even though `bf16` is a first-class element type elsewhere in the DSL. Each lowering path re-enumerates the dtypes it knows about, and users hit the cliff every time a new combination of (op, dtype, layout) has not been wired up.
Let’s look at how parsing actually works, and how it imposes semantic limits on the expressiveness of the DSL.
Under the Hood: Where the Limits Come From #
Triton compiles by parsing the body of the decorated function into a Python AST and walking it to emit MLIR. The walker, defined in `python/triton/compiler/code_generator.py`, is a subclass of `ast.NodeVisitor`. Consider how its visitor methods process control flow. A single Python `if` denotes two distinct constructs depending on the type of its condition. When `cond` is a Triton tensor, the walker emits an `scf.if` region and both branches are generated. When `cond` is a constexpr, however, the walker evaluates the condition at compile time, selects one branch, and never visits the other. The two forms are syntactically indistinguishable yet denote entirely different constructs, and the reader has to know which is in force at each site, with no way to force one or the other. The same duality governs `for` loops: `tl.static_range(8)` unrolls the body into eight copies at parse time, whereas `tl.range(8)` produces a single MLIR `scf.for`.
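A toy walker makes the loop duality concrete. `LoopCodegen` is hypothetical and emits tuples rather than MLIR, and it records a placeholder instead of walking the loop body; the point is only that the same `for` statement yields N copies of the body or one loop op, depending on which iterator name the walker sees.

```python
import ast


class LoopCodegen(ast.NodeVisitor):
    """Hypothetical walker: one `for` syntax, two different constructs."""

    def __init__(self):
        self.ir = []

    def visit_For(self, node):
        func = node.iter.func
        # Accept both `tl.static_range(...)` and a bare `static_range(...)`.
        iterator = func.attr if isinstance(func, ast.Attribute) else func.id
        trip_count = ast.literal_eval(node.iter.args[0])
        var = node.target.id
        if iterator == "static_range":
            # Unrolled during the walk: one copy of the body per iteration.
            for i in range(trip_count):
                self.ir.append(("body", var, i))
        else:
            # One loop op whose region holds a single copy of the body.
            self.ir.append(("scf.for", var, trip_count))
            self.ir.append(("body", var, None))
            self.ir.append(("scf.yield",))
        # ("body", ...) is a placeholder; a real walker would visit node.body here.


def walk(src):
    gen = LoopCodegen()
    gen.visit(ast.parse(src))
    return gen.ir
```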
Pallas: Trace Instead of Parse #
JAX’s Pallas kernel language follows JAX’s syntax and uses tracing. A Pallas kernel is an ordinary Python function that operates on JAX `Ref`s and is passed to `pl.pallas_call`, which executes it under a tracer. The body is not parsed; it is run with tracers as arguments, and the operations it performs are recorded into a jaxpr, an IR capturing the semantics of the kernel.
There is no separate DSL to parse: the operations are `jnp.max`, `jnp.exp`, and `jnp.sum`, drawn from the NumPy-shaped API that JAX users already know and dispatched through tracers that record into the kernel’s IR rather than executing on the CPU. The Pallas design document states the contrast: “JAX users are already accustomed to the benefits (and limitations) of programming with JAX and its tracing-based transformations. This means users can use closures and other familiar Python constructs when writing Pallas kernels.”
The clearest illustration of what tracing buys comes from Tokamax’s SM90 flash-attention kernel, where compile-time and runtime control flow sit side by side in a single body:
The two regimes are doing distinct jobs.

- The outer `if is_causal:` is a plain Python branch evaluated at trace time; the untaken branch is never compiled and never seen by the backend.
- It selects between two entirely different runtime control-flow shapes: the non-causal branch is a single `lax.fori_loop` over the full KV range, while the causal branch splits the range into a non-causal prefix and a causal tail, each driven by its own `lax.fori_loop`. (The `lax.cond` wrapping the tail is, per the source comment, a workaround for a compiler bug rather than a deliberate part of the algorithm.)
- The causal body itself is specialised at trace time via `functools.partial(loop_body, do_causal=True)`.
The backend never has to guess which is which: a Python `if` is always compile-time, and `lax.fori_loop` is always runtime. The trace is the program.
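Stripped of the attention math, the structure is small enough to sketch in plain Python. The sketch below is ours, not Tokamax’s: `fori_loop` is a pure-Python stand-in with the semantics of `jax.lax.fori_loop` (`body(i, carry)` returns the new carry), and `loop_body` is a toy step that counts the visible (query, key) pairs in one KV block, where a real kernel would compute a masked tile. The compile-time `if is_causal:` decides how many loops exist; `functools.partial` decides what each loop body is.

```python
from functools import partial


def fori_loop(lower, upper, body, init):
    """Pure-Python stand-in with the semantics of jax.lax.fori_loop."""
    carry = init
    for i in range(lower, upper):
        carry = body(i, carry)
    return carry


def loop_body(kv_block, acc, *, q_block, block, do_causal):
    """Toy step: count the (query, key) pairs in this KV block that are visible."""
    visible = 0
    for q in range(q_block * block, (q_block + 1) * block):
        for k in range(kv_block * block, (kv_block + 1) * block):
            if not do_causal or k <= q:  # do_causal was bound by partial, not traced
                visible += 1
    return acc + visible


def attention(q_block, num_kv_blocks, block, is_causal):
    body = partial(loop_body, q_block=q_block, block=block)
    if is_causal:  # plain Python branch: chooses the runtime loop structure
        # Prefix: KV blocks strictly before the diagonal need no mask.
        acc = fori_loop(0, q_block, partial(body, do_causal=False), 0)
        # Tail: the diagonal block, with the causal mask specialised in.
        return fori_loop(q_block, q_block + 1, partial(body, do_causal=True), acc)
    # Non-causal: a single loop over the full KV range.
    return fori_loop(0, num_kv_blocks, partial(body, do_causal=False), 0)
```

The causal total equals running the masked body over every KV block, because blocks past the diagonal contribute nothing; splitting the range simply avoids visiting them and avoids masking the fully visible prefix.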
Runtime Combinators Look Unusual
The most common complaint about tracing-based DSLs is that runtime control flow looks awkward. Where Triton lets the user write `for k in tl.range(...)` or `if dynamic_cond:` in plain Python, Pallas requires `lax.fori_loop(lb, ub, body, carry)` with an explicit carry tuple, or a decorated `@pl.when(cond)` over a freshly defined nested function. The combinators feel like a step sideways from idiomatic Python at first reading.
Authors get over this quickly because the patterns are few and soon become second nature. Idiomatic Pallas rarely uses the combinators raw; it wraps each one in a layer of compile-time Python that decides which runtime shape to emit, so a single source file can serve causal and non-causal attention by building two specialised `loop_body` closures with `functools.partial` and feeding them to two different `lax.fori_loop` arrangements (exactly what the Tokamax kernel above does). Two recurring patterns are worth naming:

- `functools.partial(body, ..., static_flag=True)` to specialise an inner-loop body before passing it to `lax.fori_loop`.
- A decorator-form `@pl.when(cond)` over a fresh `_()` function, so a runtime branch reads like a labelled block.

The combinators stay small and visible because the Python around them does the structural work first.
Compile-Time Python as a Metaprogramming Layer
Because compile-time control flow is just Python, anything Python can compute is available for metaprogramming the kernel, including patterns that a parsing-based DSL would have to special-case in its AST walker.
The Pallas `ragged_dot` Mosaic-GPU kernel is a good example. The kernel needs to store a dynamic number of rows (somewhere between 1 and `block_m`) to global memory, but TMA descriptors require statically known tile sizes. Idiomatic Pallas resolves the contradiction by unrolling a Python `while` loop into a logarithmic ladder of fixed-size stores, each guarded by a runtime bit-test on the dynamic length:
The Python `while` runs at trace time and emits `log2(block_m)` separate TMA stores with compile-time constant tile sizes; `@pl.when` wraps each in a runtime bit-test on the dynamic length, so only the stores whose tile size contributes to the desired total actually fire. The Python loop variable does double duty as both an unroll counter and a bit position. Implementing this in a parsing-based DSL would require either bespoke AST support for Python integer loops that close over kernel refs, or four hand-written `@pl.when` blocks in the source.
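A plain-Python sketch of the ladder, under our own assumptions: `when` is a stand-in for the decorator form of `pl.when` (it runs the block immediately if the condition holds), list slices stand in for TMA stores, `block_m` is a power of two, and each store’s destination offset is derived from the higher bits of `n`. None of this is the `ragged_dot` source.

```python
def when(cond):
    """Pure-Python stand-in for the decorator form of pl.when: run the block iff cond."""
    def decorator(fn):
        if cond:
            fn()
        return fn
    return decorator


def store_rows(src, dst, n, block_m):
    """Copy the first n rows of src into dst using power-of-two tile sizes only.

    Assumes block_m is a power of two and 1 <= n <= block_m.
    """
    size = block_m
    while size >= 1:                   # Python loop: unrolled at trace time
        offset = n & ~(2 * size - 1)   # rows already covered by larger tiles

        @when(n & size)                # runtime bit-test on the dynamic length
        def _():
            dst[offset:offset + size] = src[offset:offset + size]

        size //= 2
    return dst
```

For `block_m = 8` and `n = 5` (binary `101`), the 4-row store at offset 0 and the 1-row store at offset 4 fire; the 8-row and 2-row stores do not.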
Ordinary Python debugging tooling continues to work inside Pallas kernels. A plain `print` call inside a traced function prints once, at trace time, with the tracer’s abstract value, which is useful for inspecting types and shapes during compilation. For runtime values, one uses the deliberately side-effecting primitives, which survive into the compiled kernel:
The tracing approach gets four debugging affordances either free or cheap. Parsing-based DSLs can offer the same, but at considerable engineering cost: Triton, for instance, provides both source-mapped error messages and MLIR `loc` attributes, but only because the AST walker has been carefully threaded with source locations at every visitor method. The four are:
- Native breakpoints. `jax.debug.breakpoint()` drops into a real Python debugger at the corresponding point in the compiled program (`jax/_src/debugger/core.py:160`); Pallas’s interpret mode runs the kernel as a plain Python loop on the CPU, where a bare Python `breakpoint()` and `pdb` work without any special integration.
- Native print. In interpret mode, plain `print()` just works because the kernel is a normal Python function. `jax.debug.print` and `pl.debug_print` extend that to compiled execution.
- Full Python tracebacks. When the trace raises, the traceback is the actual Python call stack into the kernel source (the line of the user’s function that produced the offending operation). Triton points at the kernel source too, via the `def_file_col_number` machinery we touched on earlier; tracing simply gets the same for free by being Python.
- Op-level provenance in MLIR. JAX threads a `SourceInfo` object (`jax/_src/source_info_util.py:136`) carrying the originating Python frame through every primitive binding, and the MLIR lowering converts it to a `loc(...)` attribute via `source_info_to_location` (`jax/_src/interpreters/mlir.py:524`). Triton emits `loc` attributes as well, but only because each visitor method explicitly attaches the source location it is responsible for; tracing gets the annotation as a side effect of every primitive dispatch.
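That side effect fits in a few lines of plain Python. In the hypothetical sketch below, each operator overload walks up the stack with `traceback.extract_stack` to the user frame that triggered it and attaches that function name and line to the op it records. JAX’s `SourceInfo` machinery is considerably more elaborate; the shared principle is only that location capture lives in the dispatch path, not in per-construct visitor code.

```python
import traceback


class Tracer:
    """Hypothetical tracer whose ops carry the user source location that created them."""

    def __init__(self, ir, name):
        self.ir, self.name = ir, name

    def _emit(self, op, other):
        # Oldest first: [user frame, operator dunder, this method].
        user_frame = traceback.extract_stack(limit=3)[0]
        out = Tracer(self.ir, f"%{len(self.ir)}")
        rhs = other.name if isinstance(other, Tracer) else repr(other)
        loc = f"{user_frame.name}:{user_frame.lineno}"
        self.ir.append((out.name, op, self.name, rhs, loc))
        return out

    def __add__(self, other):
        return self._emit("add", other)

    def __mul__(self, other):
        return self._emit("mul", other)


def user_kernel(x):
    return x * 2 + 1  # both recorded ops point at this line
```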
What Tracing Gives Up #
The largest cost of tracing is the loss of succinct runtime control flow.
Every Python `if` is evaluated by the Python interpreter at trace time, so `if some_tracer > 0:` fails: the condition is an abstract value rather than a concrete boolean, and converting it to `bool` raises an error. Branching on a runtime value therefore requires the explicit combinators `lax.cond`, `lax.fori_loop`, or `pl.when`, and the resulting code is more verbose than its Triton counterpart. A function written with plain Python control flow over runtime values does not automatically work as a kernel; it must be refactored.
The Tokamax SM90 attention example earlier in the post shows what this looks like in practice: a pair of `lax.fori_loop` calls in the causal branch versus a single one in the non-causal branch, with a closure-specialised inner body to keep the combinators small. Triton would have written the same algorithm with two ordinary Python `for` loops and an ordinary `if`: visually cleaner at the call site, at the cost of the reader having to keep the compile-time / runtime distinction straight in their head. Pallas accepts the visual overhead in exchange for the syntactic distinction.
Other tracing taxes are less visible but real: shape polymorphism requires explicit machinery (`jax.export`, polymorphic shape specs) rather than falling out of the source; effectful operations need lifting into JAX primitives; and a kernel parameter that ought to be compile-time but happens to arrive as a concrete Python value during one trace and as a tracer during the next will silently produce two different jaxprs. None of these are insurmountable, but they are real friction.
The two approaches treat Python itself very differently. Triton overloads `if` to mean either compile-time or runtime branching depending on the type of the condition, and pays for the overload indefinitely: every reader must re-derive which semantics is in force at each site, and every new Python feature needs a corresponding `visit_X` method. Pallas keeps the two regimes syntactically distinct and is never obliged to implement a visitor for `ast.ListComp`: a list comprehension simply executes under the Python interpreter, producing either a list of tracers or a list of concrete values, either of which is acceptable.
Parsing brings cleaner surface syntax at the cost of making generic libraries painful to write, while tracing accepts syntactic overhead in simple kernels in exchange for that expressiveness.
Two Toy Implementations #
The architectural difference between the two approaches is compact enough to fit on a single screen. The two implementations below are sketches rather than runnable compilers (they omit a proper symbol table and a real type system), but they make the shape of each strategy concrete. The first is an AST-based mini-DSL.
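A minimal sketch of such a mini-DSL, under our own assumptions: it handles only `+`, `*`, assignment, `return` and `if`; it uses a flat environment instead of a symbol table; and it emits tuples of the form `(result, op, *operands)` instead of MLIR. `Value`, `Constexpr`, `MiniCodegen` and `compile_kernel` are hypothetical names.

```python
import ast
import operator


class Value:
    """A runtime SSA value in the emitted IR."""
    def __init__(self, name):
        self.name = name


class Constexpr:
    """A compile-time constant the walker may evaluate during codegen."""
    def __init__(self, value):
        self.value = value


def operand(v):
    return v.name if isinstance(v, Value) else repr(v.value)


BINOPS = {ast.Add: ("add", operator.add), ast.Mult: ("mul", operator.mul)}


class MiniCodegen(ast.NodeVisitor):
    def __init__(self, env):
        self.env = dict(env)  # flat name -> Value | Constexpr map; no real scoping
        self.ir = []

    def emit(self, op, *args):
        v = Value(f"%{len(self.ir)}")
        self.ir.append((v.name, op, *args))
        return v

    def generic_visit(self, node):
        # Every Python construct the user might write needs its own visit_X.
        raise NotImplementedError(f"no visitor for ast.{type(node).__name__}")

    def visit_block(self, stmts):
        for stmt in stmts:
            self.visit(stmt)

    def visit_FunctionDef(self, node):
        self.visit_block(node.body)

    def visit_Assign(self, node):
        self.env[node.targets[0].id] = self.visit(node.value)

    def visit_Return(self, node):
        self.emit("return", operand(self.visit(node.value)))

    def visit_Name(self, node):
        return self.env[node.id]

    def visit_Constant(self, node):
        return Constexpr(node.value)

    def visit_BinOp(self, node):
        if type(node.op) not in BINOPS:
            raise NotImplementedError(f"no lowering for ast.{type(node.op).__name__}")
        name, fold = BINOPS[type(node.op)]
        lhs, rhs = self.visit(node.left), self.visit(node.right)
        if isinstance(lhs, Constexpr) and isinstance(rhs, Constexpr):
            return Constexpr(fold(lhs.value, rhs.value))  # folded at compile time
        return self.emit(name, operand(lhs), operand(rhs))

    def visit_If(self, node):
        cond = self.visit(node.test)
        if isinstance(cond, Constexpr):
            # Compile-time branch: the untaken side is never visited.
            self.visit_block(node.body if cond.value else node.orelse)
            return
        # Runtime branch: both sides are generated into one scf.if region.
        before = dict(self.env)
        self.emit("scf.if", cond.name)
        self.visit_block(node.body)
        then_env, self.env = self.env, dict(before)
        self.emit("scf.else")
        self.visit_block(node.orelse)
        else_env, self.env = self.env, dict(before)
        self.emit("scf.end")
        for var in sorted(then_env.keys() & else_env.keys()):
            if then_env[var] is not else_env[var]:
                # Stand-in for the results an scf.if region would yield.
                self.env[var] = self.emit(
                    "merge", cond.name, operand(then_env[var]), operand(else_env[var]))


def compile_kernel(src, **args):
    fn = ast.parse(src).body[0]
    gen = MiniCodegen({a.arg: args[a.arg] for a in fn.args.args})
    gen.visit(fn)
    return gen.ir
```

Compiling a kernel with `flag=Constexpr(True)` emits only the taken branch; with `flag=Value("flag")` it emits both branches inside an `scf.if` region plus a `merge`. Anything else, such as a list comprehension, raises `NotImplementedError` until someone writes the corresponding visitor.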
The tracing analogue is comparable in size.
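A matching sketch of the tracing side, emitting the same style of `(result, op, *operands)` tuples. `Tracer`, `Graph`, `cond` and `trace` are hypothetical names, and `cond` records both branches followed by a `merge` op, a simplification of how `lax.cond` actually lowers.

```python
class Tracer:
    """A traced runtime value: operations on it are recorded, not executed."""

    def __init__(self, graph, name):
        self.graph, self.name = graph, name

    def __bool__(self):
        # A Python `if` needs a concrete boolean, which a tracer does not have.
        raise TypeError(f"cannot branch on traced value {self.name}; use cond()")

    def __add__(self, other):
        return self.graph.emit("add", self, other)

    def __mul__(self, other):
        return self.graph.emit("mul", self, other)


def operand(v):
    return v.name if isinstance(v, Tracer) else repr(v)


class Graph:
    def __init__(self):
        self.ir = []

    def emit(self, op, *args):
        out = Tracer(self, f"%{len(self.ir)}")
        self.ir.append((out.name, op, *map(operand, args)))
        return out


def cond(pred, true_fn, false_fn, *operands):
    """Explicit runtime branch: trace both sides into one region, then merge."""
    g = pred.graph
    g.emit("scf.if", pred)
    t = true_fn(*operands)
    g.emit("scf.else")
    f = false_fn(*operands)
    g.emit("scf.end")
    return g.emit("merge", pred, t, f)


def trace(fn, *arg_names, **static):
    """Run fn on fresh tracers; keyword arguments stay concrete Python values."""
    graph = Graph()
    result = fn(*(Tracer(graph, n) for n in arg_names), **static)
    graph.emit("return", result)
    return graph.ir


def kernel(x, flag):
    if flag:  # fine when flag is a concrete bool; raises when it is a Tracer
        y = x * 2
    else:
        y = x + 1
    return y


def kernel_runtime(x, flag):
    return cond(flag, lambda v: v * 2, lambda v: v + 1, x)


def kernel_comprehension(x):
    ys = [x * c for c in (2, 3)]  # plain Python; no visitor required
    return ys[0] + ys[1]
```

`trace(kernel, "x", flag=True)` records a single `mul`; `trace(kernel, "x", "flag")` raises `TypeError` at the `if`, and `kernel_runtime` is the refactored form. The list comprehension in `kernel_comprehension` needs no support code at all.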
The two implementations differ by only a handful of lines, yet each difference between them grows into a larger distinction between Triton and Pallas. The AST-based version requires a visitor for every Python construct the user might invoke, whereas the tracing version inherits every Python construct for free, provided it terminates in operator dispatch. In the AST version, `if x:` demands special-casing inside `visit_If` for the constexpr and runtime cases; in the tracing version the same line either succeeds, because `x` is a concrete Python boolean, or raises, because `x` is a tracer and the author must instead reach for the explicit `cond(pred, ...)` combinator. Both behaviours are useful; only one needs a 1700-line visitor to implement.
Conclusion #
Tracing pays the cost of awkward runtime syntax. In exchange it gets a much simpler and more robust compiler implementation, and the ability to express complex generic algorithms through Python metaprogramming, which is often what a reusable set of libraries requires.
Source for the snippets above: Triton commit at triton-lang/triton, CUTLASS at NVIDIA/cutlass, JAX at jax-ml/jax. The two CuTe bugs are #3266 and #3268. The compile timing in the receipt box was measured locally in `nvidia/cuda:12.5.0-devel-ubuntu22.04`; the figure will vary with toolkit version and host configuration. Corrections and counter-examples are welcome by email.