
A Case for Tracing Based DSL Kernel Languages

NVIDIA's C++ template-based CUTLASS library for GPU kernels can take around 20 seconds to compile a single kernel, and NVIDIA's bug tracker records over 17 minutes for two Ampere GEMM kernels, prompting a shift toward Python-embedded DSLs; NVIDIA reports up to 100x compilation speedups for Blackwell GEMM kernels with its CuTe Python DSL. Among those DSLs, the article argues that a tracing-based approach, as used by JAX's Pallas, is often preferable to the source-parsing approach of Triton and CuTe-DSL, because it makes generic, reusable kernel libraries easier to write.

16 min read · published May 27, 2026

On the architectural divide between parsing and tracing kernel DSLs, and what tends to go wrong in each.

For a long time the language for writing NVIDIA GPU kernels was exclusively CUDA, but since Triton appeared, a wave of Pythonic DSLs has followed: CuTe-DSL, cuTile, Pallas, Gluon, Warp, and the more recent TileLang used in DeepSeek’s DeepGEMM. Most of these systems share the same goal of lowering a tile-oriented program into PTX or LLVM-IR, and are embedded in Python.

The question is how to embed the DSL into Python. Triton and CuTe-DSL parse the source AST. Pallas runs the function under abstract values and traces the resulting operations. (PyTorch’s torch.compile intercepts CPython bytecode rather than source, but that is still parsing, just against a smaller, post-desugared grammar; the same trade-offs apply.)

Most DSLs follow Triton’s lead and use parsing. This essay argues for the alternative: a tracing-based approach is often preferable.

CUDA and Templates #

A CUDA kernel directly specifies the execution code for each thread. A textbook fused-softmax kernel in CUDA looks roughly like this:

The element type T and the block size BLOCK_SIZE must be known at compile time, as __shared__ memory is statically sized, and the compiler must specialise loop bounds to enable vectorisation of the body. Hence any expansion of the supported configuration space multiplies the number of instantiations. Three element types and four block sizes already imply twelve instantiations, and the responsibility for dispatching among them rests with the caller.
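The combinatorics can be made concrete with a small Python sketch. It enumerates a hypothetical configuration space (the dtype names, the four block sizes and the softmax_kernel name are illustrative, not taken from a real kernel) and builds the caller-side dispatch table that the text says the caller must own; each string stands in for one separately compiled template instance, and the "smallest block that covers the row" policy is an assumption.

```python
from itertools import product

# Hypothetical configuration space; each (dtype, block size) pair is one instantiation.
DTYPES = ("float", "half", "nv_bfloat16")
BLOCK_SIZES = (128, 256, 512, 1024)

def instantiate(dtype: str, block_size: int) -> str:
    # Stand-in for one compiled template instance of a hypothetical softmax_kernel.
    return f"softmax_kernel<{dtype}, {block_size}>"

# Caller-side dispatch table: one entry per instantiation the build must produce.
DISPATCH = {(d, b): instantiate(d, b) for d, b in product(DTYPES, BLOCK_SIZES)}

def dispatch(dtype: str, n_cols: int) -> str:
    # Assumed policy: pick the smallest block size that covers the whole row.
    for block_size in BLOCK_SIZES:
        if n_cols <= block_size:
            return DISPATCH[(dtype, block_size)]
    raise ValueError(f"no instantiation covers {n_cols} columns")

print(len(DISPATCH))          # 12
print(dispatch("half", 300))  # softmax_kernel<half, 512>
```

Every new axis of the configuration space multiplies the size of this table, and in C++ each entry is a separate compilation.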

Keep adding templates and generalisations to such a kernel, and one eventually arrives at a heavily templated, CUTLASS-like design.

CUTLASS: Building Blocks for CUDA Kernels #

CUTLASS is what C++ template metaprogramming looks like when it is adopted as the way to write GPU kernels. Consider the declaration of its principal Gemm class, the entry point most users first encounter, from include/cutlass/gemm/device/gemm.h:

A fragment of the canonical Hopper warp-specialized GEMM example (examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm.cu, lines 83–142) shows how a user composes a kernel from nested CollectiveBuilders, each a template that pulls in dozens of further instantiations:

Shape<_128,_128,_32> denotes a type rather than a value, so the compiler must instantiate every dependent template once per distinct shape, and that is what drives the long compile times.

Compile Time: The Cost of Templates

We compiled 48_hopper_warp_specialized_gemm, a single CuTe-based GEMM file of roughly 500 lines, with -c and no benchmark harness, inside an nvidia/cuda:12.5.0-devel-ubuntu22.04 container, invoking nvcc -std=c++17 -arch=sm_90a. The steady-state nvcc time, averaged over two warm runs on a consumer laptop, was ~20.5 s to compile a single kernel for a single architecture.

A full CUTLASS build targets several architectures, and the cost multiplies accordingly. NVIDIA’s own bug tracker records 17m22s for two Ampere i16832gemm_s8 kernels (issue #1042) and approximately two minutes for a 30-line CuTe-DSL kernel (issue #2677). The NVIDIA developer blog post introducing the CuTe Python DSL in November 2025 frames its principal contribution as “up to two orders of magnitude reduced” compile times relative to C++ CUTLASS, with a quoted “~100x compilation speedup” for Blackwell GEMM and “30-50x” for flash attention.

C++ templates compile too slowly for the iteration speed kernel authors need.

Triton: DSL Embedded Into Python #

Using Triton is straightforward: decorate a Python function with @triton.jit, mark the compile-time parameters with tl.constexpr, and write the kernel body in something close to NumPy. Triton also simplifies the programming model, focusing on the tile a single thread block operates on rather than on the code for an individual thread.

Triton is a pleasure to use when it works: the program looks straightforward and does what one would expect. Integration into PyTorch is first-class, there is no build system to set up, and it is relatively hard to construct a malformed program that triggers a crash. However, when one wants to write a reusable generic set of libraries in Triton, things get tricky.

Parsing Limitations #

Suppose we want to expose a fused matmul whose epilogue (activation, scaling, or fusion applied to the accumulator after the inner product) can be supplied by the caller as a Python callable:

In a tracing-based framework this is one line: the user hands in a Python callable, and the trace records whatever operations the callable performs. In Triton this is tricky to achieve, and almost every other limitation in this section is a variation of the same underlying constraint.
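To see why a callable costs nothing in a tracing framework, here is a minimal plain-Python sketch. Trace, Tracer, maximum and fused_matmul are hypothetical stand-ins, not any real DSL's API: the library calls whatever epilogue it receives on a tracer, and the tracer's overloaded operators append to an op list, so a lambda, a closure or a named function are all handled the same way.

```python
from dataclasses import dataclass, field

@dataclass
class Trace:
    ops: list = field(default_factory=list)

    def emit(self, op, *args):
        # Record one operation and hand back a tracer naming its result.
        name = f"%{len(self.ops)}"
        self.ops.append((name, op, args))
        return Tracer(self, name)

@dataclass
class Tracer:
    trace: Trace
    name: str

    def __mul__(self, other):
        return self.trace.emit("mul", self.name, ref(other))

def ref(value):
    # Tracers are referred to by name; Python constants are inlined.
    return value.name if isinstance(value, Tracer) else value

def maximum(x, y):
    # A "primitive" an epilogue may call: it records instead of computing.
    return x.trace.emit("max", x.name, ref(y))

def fused_matmul(epilogue):
    # Hypothetical library entry point: the epilogue is just a Python callable.
    t = Trace()
    acc = t.emit("matmul", "A", "B")
    t.emit("store", ref(epilogue(acc)))
    return t.ops

scale = 2.0
relu_ops = fused_matmul(lambda x: maximum(x, 0.0))
scaled_ops = fused_matmul(lambda x: maximum(x * scale, 0.0))  # captures `scale`
for op in scaled_ops:
    print(op)
```

The library never inspects the callable's source; it only observes which operations the callable performs.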

What the Ecosystem Actually Does

Before we look at why, it helps to see how production Python kernel libraries handle this in practice. Across Liger-Kernel, FlagGems, Quack, and DeepGEMM, the answer is consistent: enumerate variants statically, dispatch by enum or string tag, never accept a runtime callable. Quack’s GEMM signature is representative:

Why Lambdas Fail

If we try to pass a callable in Triton (say, `matmul(A, B, activation=lambda x: tl.where(x > 0, x, 0))`), it fails at compile time:

Following the error’s advice and wrapping the lambda in triton.jit fails earlier, during construction of the JITFunction object:

The constructor calls inspect.getsourcelines(fn) and expects the returned source to contain a def name( line, matched via a module-level regex at jit.py:27. It uses the def to compute indentation and dedent the body before handing it to the AST walker; a lambda’s source line contains no def, so the construction step fails.

The supported workaround is to lift the activation into a named @triton.jit def and pass that through the constexpr argument, which does work. But that is exactly what an enum entry already refers to, and the ecosystem’s convergence on enums reflects this constraint: once the supported shape is “any caller-defined @triton.jit-decorated function passed by name,” the library author may as well enumerate the small set they intend to support, autotune over it, and present a string-typed API.
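The failure mode is easy to reproduce without Triton. The sketch below applies an illustrative regex, written in the spirit of the check described above but not copied from jit.py, to the kind of source text inspect.getsourcelines returns for a named function and for a lambda. The snippets are strings that are only scanned, never executed.

```python
import re

# Illustrative pattern (an assumption, not Triton's exact regex): capture the
# indentation and name of a "def name(" line.
DEF_RE = re.compile(r"^([ \t]*)def\s+(\w+)\s*\(", re.MULTILINE)

def find_def(source: str):
    """Return (function name, indentation width) of the first def line, or None."""
    match = DEF_RE.search(source)
    if match is None:
        return None
    return match.group(2), len(match.group(1))

# Source text of the kind inspect.getsourcelines reports for each object.
named = """\
    @triton.jit
    def relu(x):
        return tl.where(x > 0, x, 0)
"""
anonymous = "out = matmul(A, B, activation=lambda x: tl.where(x > 0, x, 0))\n"

print(find_def(named))      # ('relu', 4)
print(find_def(anonymous))  # None
```

For the lambda there is no def line to anchor the indentation computation, which is the point at which a source-based frontend has nothing to work with.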

Closures

A factory-style helper that captures configuration in an enclosing scope is a common pattern in Python:

In Triton this fails unconditionally:

Triton does not implement closure capture at all; every free name in a kernel body is resolved against module globals and must already be a tl.constexpr.
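CPython's own function attributes show why a closure is invisible to a walker that only consults module globals. In this stdlib-only sketch (make_scaled_relu is a hypothetical factory), the captured value lives in the function's closure cells, not in its globals; a tracer never needs to look it up, because calling the function reads the cell directly.

```python
def make_scaled_relu(alpha: float):
    # Factory: alpha is captured from the enclosing scope, not a module global.
    def scaled_relu(x):
        return max(x, 0.0) * alpha
    return scaled_relu

act = make_scaled_relu(0.5)

# What a source walker sees: a free name that is not a module global.
print(act.__code__.co_freevars)          # ('alpha',)
print("alpha" in act.__globals__)        # False
# Where the value actually lives: a closure cell.
print(act.__closure__[0].cell_contents)  # 0.5

# What a tracer does: just call the function; Python resolves alpha itself.
print(act(4.0))  # 2.0
```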

Higher-Order Primitives

The same constraint shows up most sharply in the region-builder APIs of tl.reduce and tl.associative_scan: even a properly @triton.jit-decorated combine function cannot be passed through a kernel argument at all. The combine must be resolved lexically in the kernel’s enclosing scope, and the standard library copes by shipping one named combine per parameter combination: _argmax_combine_tie_break_left and _argmax_combine_tie_break_fast exist as two file-scoped functions because the boolean argument they differ in cannot be threaded in at call time (triton/python/triton/language/standard.py:158-165).

This is arguably more a limitation of the specific region-builder API than of the parsing approach as such; one could imagine a Triton in which tl.reduce accepted callable arguments. But the analogous JAX primitive (jax.lax.associative_scan(lambda x, y: x + y, xs)) does accept a lambda.
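Where the combine is an ordinary callable, the tie-break flag that forces two file-scoped functions becomes a plain parameter. The following sketch is a sequential reference scan in pure Python, not JAX's parallel implementation; argmax_combine is hypothetical, with a tie-break rule modelled loosely on the shape of the two Triton helpers, and functools.partial threads the flag in at call time.

```python
from functools import partial
from itertools import accumulate

def associative_scan(combine, xs):
    # Sequential reference semantics: out[i] = combine(out[i - 1], xs[i]).
    return list(accumulate(xs, combine))

def argmax_combine(a, b, tie_break_left: bool):
    # Elements are (value, index) pairs; the larger value wins.
    (va, ia), (vb, ib) = a, b
    # With tie_break_left, equal values keep the lower index; otherwise b wins ties.
    tie = tie_break_left and va == vb and ia < ib
    return a if (va > vb or tie) else b

xs = [(3, 0), (7, 1), (7, 2), (1, 3)]
left = associative_scan(partial(argmax_combine, tie_break_left=True), xs)
fast = associative_scan(partial(argmax_combine, tie_break_left=False), xs)
sums = associative_scan(lambda x, y: x + y, [1, 2, 3, 4])

print(left)  # [(3, 0), (7, 1), (7, 1), (7, 1)]
print(fast)  # [(3, 0), (7, 1), (7, 2), (7, 2)]
print(sums)  # [1, 3, 6, 10]
```

One combine function covers both variants, and a lambda works wherever a named function does.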

Aren’t Those Fixable?

Each issue described above is, in principle, fixable. But the picture stays the same: supporting metaprogramming in Triton means implementing more and more of Python inside the AST walker.

CuTe-DSL: The Same Pattern, Sharper Edges #

While writing CuTe-DSL kernels we encountered two such failures and filed bug reports against both. Each illustrates the pattern, and both come from a parsing frontend:

- cutlass#3268: storage.<field>.get_tensor(...) works at the top level of a kernel but fails with “encountered a user-defined Python object” when placed inside a runtime if block. The shared-storage tensor lookup is a Python object that cannot be carried across the boundary of the lowered scf.if region; the surface language does not signal this in any way, so semantically identical code succeeds or fails depending on where it sits in the source. The lowering treats values in “Python” blocks and values in “IR” blocks differently, and the rules for what survives the transition are not part of the user-facing language.
- cutlass#3266: nvvm.load.ext rejects BFloat16 (“Unsupported FP type for ExtLoadOp”) even though bf16 is a first-class element type elsewhere in the DSL. Each lowering path re-enumerates the dtypes it knows about, and users reach the cliff every time a new combination of (op, dtype, layout) has not been wired up.

Let’s look at how parsing actually works, and how it limits the expressiveness of the DSL.

Under the Hood: Where the Limits Come From #

Triton compiles by parsing the body of the decorated function into a Python AST and walking it to emit MLIR. The walker, a subclass of ast.NodeVisitor, lives in python/triton/compiler/code_generator.py. Consider the visitor methods used to process control flow:

A single Python if denotes two distinct constructs depending on the type of its condition. When cond is a Triton tensor, the walker emits an scf.if region and both branches are generated. However, when cond is a constexpr, the walker evaluates the condition at compile time, selects one branch, and never visits the other. The same duality applies to for loops: tl.static_range(8) unrolls the body into eight copies at parse time, whereas tl.range(8) produces a single MLIR scf.for. The two are syntactically indistinguishable yet denote entirely different constructs, and the reader has to know which is in force at each site, with no way to force one or the other.
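Python's own ast module makes the syntactic point concrete: the walker receives the same node shape for both loops and for both kinds of if, and only an attribute name, or the type of a value it learns about later, tells them apart. In this sketch the kernel fragments are illustrative strings that are parsed but never executed.

```python
import ast

# Kernel fragments as strings: they are parsed, never executed.
static_loop = "for i in tl.static_range(8):\n    acc += x\n"
runtime_loop = "for i in tl.range(8):\n    acc += x\n"
const_if = "if BLOCK > 64:\n    y = x\n"
tensor_if = "if mask > 64:\n    y = x\n"

def shape(src: str) -> str:
    # ast.dump omits line/column attributes by default, leaving only the tree shape.
    return ast.dump(ast.parse(src))

# The loops differ only in one attribute name inside an otherwise identical For node.
same_loop = shape(static_loop).replace("static_range", "range") == shape(runtime_loop)
# The two ifs differ only in a Name id; nothing in the tree records its type.
same_if = shape(const_if).replace("BLOCK", "mask") == shape(tensor_if)
print(same_loop, same_if)  # True True
```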

Pallas: Trace Instead of Parse #

JAX’s Pallas kernel language follows JAX’s syntax and uses tracing. A Pallas kernel is an ordinary Python function that operates on JAX Refs and is passed to pl.pallas_call, which executes it under a tracer. The body is not parsed; it is run with tracers as arguments to generate a jaxpr, an IR capturing the semantics of the kernel.

There is no separate DSL to parse: the operations are jnp.max, jnp.exp, and jnp.sum, drawn from the NumPy-shaped API that JAX users already know and dispatched through tracers that record into the kernel’s IR rather than executing on the CPU. The Pallas design document states the contrast: “JAX users are already accustomed to the benefits (and limitations) of programming with JAX and its tracing-based transformations. This means users can use closures and other familiar Python constructs when writing Pallas kernels.”

The clearest illustration of what tracing buys comes from Tokamax’s SM90 flash-attention kernel, where compile-time and runtime control flow sit side by side in a single body:

The two regimes are doing distinct jobs. The outer `if is_causal:` is a plain Python branch evaluated at trace time; the untaken branch is never compiled and never seen by the backend. It selects between two entirely different runtime control-flow shapes: the non-causal branch is a single lax.fori_loop over the full KV range, while the causal branch splits the range into a non-causal prefix and a causal tail, each driven by its own lax.fori_loop. (The lax.cond wrapping the tail is, per the source comment, a workaround for a compiler bug rather than a deliberate part of the algorithm.) The causal body itself is specialised at trace time via functools.partial(loop_body, do_causal=True). The backend never has to guess which is which: a Python if is always compile-time, lax.fori_loop is always runtime. The trace is the program.
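The division of labour can be mimicked in plain Python. The sketch below uses a reference fori_loop with the documented semantics of lax.fori_loop (carry = body(i, carry) for i in [lower, upper)); loop_body and attention_row are hypothetical stand-ins for the Tokamax code, which is not reproduced here, and a running sum of scores replaces the real attention accumulator. The causal split point is simplified to the query index.

```python
from functools import partial

def fori_loop(lower, upper, body, init):
    # Reference semantics of lax.fori_loop: carry = body(i, carry) for i in [lower, upper).
    carry = init
    for i in range(lower, upper):
        carry = body(i, carry)
    return carry

def loop_body(kv_idx, acc, *, q_idx, scores, do_causal):
    # Hypothetical inner body; the running sum stands in for the attention accumulator.
    s = scores[kv_idx]
    if do_causal:  # a Python bool fixed by partial; under a tracer, resolved while tracing
        s = s * (kv_idx <= q_idx)  # data-dependent mask, standing in for a where()
    return acc + s

def attention_row(q_idx, scores, is_causal):
    n = len(scores)
    if is_causal:  # trace-time branch: selects which loop structure exists at all
        prefix = partial(loop_body, q_idx=q_idx, scores=scores, do_causal=False)
        tail = partial(loop_body, q_idx=q_idx, scores=scores, do_causal=True)
        acc = fori_loop(0, q_idx, prefix, 0.0)  # keys before the diagonal: no mask
        return fori_loop(q_idx, n, tail, acc)   # diagonal onward: masked
    body = partial(loop_body, q_idx=q_idx, scores=scores, do_causal=False)
    return fori_loop(0, n, body, 0.0)

scores = [1.0, 2.0, 3.0, 4.0]
print(attention_row(1, scores, is_causal=False))  # 10.0
print(attention_row(1, scores, is_causal=True))   # 3.0
```

The outer if decides how many runtime loops exist, and partial decides which specialised body each loop receives; neither decision survives into the runtime program.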

Runtime Combinators Look Unusual

The most common complaint about tracing-based DSLs is that runtime control flow looks awkward. Where Triton lets the user write `for k in tl.range(...)` or `if dynamic_cond:` in plain Python, Pallas requires lax.fori_loop(lb, ub, body, carry) with an explicit carry tuple, or a decorated @pl.when(cond) over a freshly defined nested function. The combinators feel like a step sideways from idiomatic Python at first reading.

Authors get over this quickly because the patterns are few and soon become second nature. Idiomatic Pallas rarely uses the combinators raw; it wraps each one in a layer of compile-time Python that decides which runtime shape to emit, so a single source file can serve causal and non-causal attention by building two specialised loop_body closures with functools.partial and feeding them to two different lax.fori_loop arrangements (exactly what the Tokamax kernel above does). Two recurring patterns are worth naming:

- functools.partial(body, ..., static_flag=True) to specialise an inner-loop body before passing it to lax.fori_loop.
- A decorator-form `@pl.when(cond)` over a fresh `_()` function, so a runtime branch reads like a labelled block.

The combinators stay small and visible because the Python around them does the structural work first.
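The decorator idiom stops looking odd once one notices it is only a function call. The sketch below is an eager, interpret-mode-style stand-in for the pattern (when is a hypothetical helper, not pl.when itself): decorating _ runs the block immediately if the condition holds. Under a tracer the condition would be a traced value and the decorator would record a conditional region instead, which this sketch omits.

```python
def when(cond):
    # Decorator factory: run the decorated block at once, but only if cond is truthy.
    def decorator(block):
        if cond:
            block()
        return None  # the throwaway name "_" ends up bound to None
    return decorator

log = []
for row in range(4):
    @when(row % 2 == 0)
    def _():
        log.append(f"store row {row}")

print(log)  # ['store row 0', 'store row 2']
```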

Compile-Time Python as a Metaprogramming Layer

Because compile-time control flow is just Python, anything Python can compute is available for metaprogramming the kernel, including patterns that a parsing-based DSL would have to special-case in its AST walker. The Pallas ragged_dot Mosaic-GPU kernel is a good example. The kernel needs to store a dynamic number of rows (somewhere between 1 and block_m) to global memory, but TMA descriptors require statically known tile sizes. Idiomatic Pallas resolves the contradiction by unrolling a Python while loop into a logarithmic ladder of fixed-size stores, each guarded by a runtime bit-test on the dynamic length:

The Python while runs at trace time and emits log2(block_m) separate TMA stores with compile-time constant tile sizes; @pl.when wraps each in a runtime bit-test on the dynamic length, so only the stores whose tile size contributes to the desired total actually fire. The Python loop variable does double duty as both an unroll counter and a bit position. Implementing this in a parsing-based DSL would require either bespoke AST support for Python integer loops that close over kernel refs, or four hand-written @pl.when blocks in the source.
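The ladder's arithmetic is easy to check in isolation. In this pure-Python sketch the store list stands in for TMA copies, block_m = 8 is an assumed value, block_m is assumed to be a power of two, and the real kernel's offsets and bookkeeping may differ: a trace-time while loop produces fixed tile sizes block_m, block_m/2, ..., 1, and a runtime bit test on the dynamic row count decides which stores fire.

```python
def build_store_ladder(block_m: int):
    # Trace time: a Python while loop unrolls into fixed tile sizes block_m, block_m/2, ..., 1.
    ladder, size = [], block_m
    while size >= 1:
        ladder.append(size)
        size //= 2
    return ladder

def run_ladder(ladder, n_rows: int):
    # Runtime: each fixed-size store fires only if its bit is set in the dynamic length.
    stores, offset = [], 0
    for size in ladder:
        if n_rows & size:  # stands in for a @pl.when guard on n_rows & size
            stores.append((offset, size))  # rows [offset, offset + size)
            offset += size
    return stores

ladder = build_store_ladder(8)
print(ladder)                 # [8, 4, 2, 1]
print(run_ladder(ladder, 5))  # [(0, 4), (4, 1)]
```

Because the stores are laid out back to back, any row count from 1 to block_m is covered exactly by the binary decomposition of the count.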

Ordinary Python debugging tooling continues to work inside Pallas kernels. A plain print call inside a traced function prints once, at trace time, with the tracer’s abstract value, which is useful for inspecting types and shapes during compilation. For runtime values, the deliberately side-effecting primitives go through:

The tracing approach gets four debugging affordances either free or cheap. Parsing-based DSLs can offer the same, but at considerable engineering cost: Triton, for instance, provides both source-mapped error messages and MLIR loc attributes, but only because the AST walker has been carefully threaded with source locations at every visitor method. The four are:

- Native breakpoints. jax.debug.breakpoint() drops into a real Python debugger at the corresponding point in the compiled program (jax/_src/debugger/core.py:160); Pallas’s interpret mode runs the kernel as a plain Python loop on the CPU, where a bare Python breakpoint() and pdb work without any special integration.
- Native print. In interpret mode, plain print() just works because the kernel is a normal Python function. jax.debug.print and pl.debug_print extend that to compiled execution.
- Full Python tracebacks. When the trace raises, the traceback is the actual Python call stack into the kernel source (the line of the user’s function that produced the offending operation). Triton points at the kernel source too, via the def_file_col_number machinery we touched on earlier; tracing simply gets the same for free by being Python.
- Op-level provenance in MLIR. JAX threads a SourceInfo object (jax/_src/source_info_util.py:136) carrying the originating Python frame through every primitive binding, and the MLIR lowering converts it to a loc(...) attribute via source_info_to_location (jax/_src/interpreters/mlir.py:524). Triton emits loc attributes as well, but only because each visitor method explicitly attaches the source location it is responsible for; tracing gets the annotation as a side effect of every primitive dispatch.
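How tracing gets provenance as a side effect of dispatch can be sketched in a few lines. In this toy (Tracer and OPS are hypothetical; JAX's SourceInfo machinery is far more careful, and sys._getframe is a CPython-specific helper), each overloaded operator looks two frames up the call stack and records the user function and line that invoked it, with no cooperation from any source walker.

```python
import sys

OPS = []  # (op, user function, line) for every recorded operation

class Tracer:
    """Toy traced value that tags each recorded op with its Python call site."""

    def _record(self, op):
        # Frame 0 is _record, 1 is the operator method, 2 is the user's code.
        caller = sys._getframe(2)
        OPS.append((op, caller.f_code.co_name, caller.f_lineno))
        return Tracer()

    def __add__(self, other):
        return self._record("add")

    def __mul__(self, other):
        return self._record("mul")

def user_kernel(x):
    y = x * 2
    return y + 1

user_kernel(Tracer())
for op, fn, line in OPS:
    print(op, fn, line)
```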

What Tracing Gives Up #

The largest cost of tracing is the loss of succinct control flow. Every Python if is evaluated by the Python interpreter at trace time, so `if some_tracer > 0:` is illegal, because the condition is an abstract value rather than a concrete boolean, and Python raises. Runtime branching and looping therefore require the explicit combinators lax.cond, lax.fori_loop, or pl.when, and the resulting code is more verbose than its Triton counterpart. A function written with plain Python control flow does not automatically work as a kernel; it must be refactored.

The Tokamax SM90 attention example earlier in the post shows what this looks like in practice: a pair of lax.fori_loop calls in the causal branch versus a single one in the non-causal branch, with a closure-specialised inner body to keep the combinators small. Triton would have written the same algorithm with two ordinary Python for loops and an ordinary if: visually cleaner at the call site, at the cost of the reader having to keep the compile-time / runtime distinction straight in their head. Pallas pays the visual overhead to get the syntactic distinction.

Other tracing taxes are less visible but real: shape polymorphism requires explicit machinery (jax.export, polymorphic shape specs) rather than falling out of the source; effectful operations need lifting into JAX primitives; and a kernel parameter that ought to be compile-time but happens to arrive as a concrete Python value during one trace and as a tracer during the next will silently produce two different jaxprs. None of these are insurmountable, but they are real friction.
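The silent-retrace hazard can be demonstrated with a toy tracer. In this sketch (DYNAMIC, Tracer and trace are hypothetical helpers, not JAX APIs), a block_size that arrives as a plain int is folded into the recorded program as a constant, while the same argument arriving as a tracer becomes a program input, so one Python function yields two different IRs without any warning.

```python
DYNAMIC = object()  # sentinel meaning "this argument is a runtime input"

class Tracer:
    def __init__(self, name, ops):
        self.name, self.ops = name, ops

    def __mul__(self, other):
        # Tracers are referenced by name; concrete Python values are baked in.
        rhs = other.name if isinstance(other, Tracer) else repr(other)
        out = Tracer(f"%{len(self.ops)}", self.ops)
        self.ops.append(f"{out.name} = mul {self.name}, {rhs}")
        return out

def trace(fn, *args):
    # DYNAMIC arguments become tracers; everything else stays a Python value.
    ops = []
    traced = [Tracer(f"%arg{i}", ops) if a is DYNAMIC else a for i, a in enumerate(args)]
    fn(*traced)
    return ops

def kernel(x, block_size):
    return x * block_size

print(trace(kernel, DYNAMIC, 128))      # ['%0 = mul %arg0, 128']
print(trace(kernel, DYNAMIC, DYNAMIC))  # ['%0 = mul %arg0, %arg1']
```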

The two approaches treat Python itself very differently. Triton overloads if to mean either compile-time or runtime branching depending on the type of the condition, and pays for the overload indefinitely: every reader must re-derive which semantics is in force at each site, and every new Python feature needs a corresponding visit_X method. Pallas keeps the two regimes syntactically distinct and is never obliged to implement a visitor for ast.ListComp: a list comprehension simply executes under the Python interpreter, producing either a list of tracers or a list of concrete values, either of which is acceptable.

Parsing brings cleaner surface syntax at the cost of making generic libraries painful to write, while tracing accepts syntactic overhead in simple kernels in exchange for that expressiveness.

Two Toy Implementations #

The architectural difference between the two approaches is compact enough to fit on a single screen. The two implementations below are sketches rather than runnable compilers (they omit a proper symbol table and a real type system), but they make the shape of each strategy concrete. The first is an AST-based mini-DSL.
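As a stand-in under stated assumptions (MiniCodegen, Value and compile_kernel are illustrative names; like the sketches described above, it has no yields, scoping or type system), here is a small AST-walking mini-DSL. Names passed as constexprs are plain Python values; every other parameter is a runtime SSA value. visit_If folds constexpr conditions and emits an scf.if region otherwise, and any construct without a visitor is rejected.

```python
import ast
import itertools

class Value:
    """A runtime SSA value in the emitted IR (as opposed to a constexpr)."""
    def __init__(self, name):
        self.name = name
    def __str__(self):
        return self.name

class MiniCodegen(ast.NodeVisitor):
    """Toy parsing-based DSL: walks a function's AST and emits textual IR."""

    BINOPS = {ast.Add: "add", ast.Mult: "mul"}

    def __init__(self, constexprs):
        self.env = dict(constexprs)  # compile-time names bound to plain Python values
        self.ir = []
        self.ids = itertools.count()

    def emit(self, text):
        value = Value(f"%{next(self.ids)}")
        self.ir.append(f"{value} = {text}")
        return value

    def generic_visit(self, node):
        # No visitor, no support: every Python construct needs its own method.
        raise NotImplementedError(f"unsupported syntax: {type(node).__name__}")

    def visit_FunctionDef(self, node):
        for arg in node.args.args:  # non-constexpr parameters are runtime values
            self.env.setdefault(arg.arg, Value(f"%{arg.arg}"))
        for stmt in node.body:
            self.visit(stmt)

    def visit_Name(self, node):
        return self.env[node.id]

    def visit_Constant(self, node):
        return node.value

    def visit_Assign(self, node):
        self.env[node.targets[0].id] = self.visit(node.value)

    def visit_BinOp(self, node):
        lhs, rhs = self.visit(node.left), self.visit(node.right)
        return self.emit(f"{self.BINOPS[type(node.op)]} {lhs}, {rhs}")

    def visit_Compare(self, node):
        if not isinstance(node.ops[0], ast.Gt):
            raise NotImplementedError("only '>' is supported")
        lhs, rhs = self.visit(node.left), self.visit(node.comparators[0])
        if isinstance(lhs, Value) or isinstance(rhs, Value):
            return self.emit(f"cmpgt {lhs}, {rhs}")
        return lhs > rhs  # both sides constexpr: fold at compile time

    def visit_If(self, node):
        cond = self.visit(node.test)
        if not isinstance(cond, Value):
            # Constexpr condition: walk one branch and never visit the other.
            for stmt in (node.body if cond else node.orelse):
                self.visit(stmt)
            return
        # Runtime condition: emit a region and walk both branches.
        self.ir.append(f"scf.if {cond} {{")
        for stmt in node.body:
            self.visit(stmt)
        self.ir.append("} else {")
        for stmt in node.orelse:
            self.visit(stmt)
        self.ir.append("}")

    def visit_Return(self, node):
        self.ir.append(f"return {self.visit(node.value)}")

def compile_kernel(src, **constexprs):
    gen = MiniCodegen(constexprs)
    gen.visit(ast.parse(src).body[0])
    return gen.ir

SRC = """
def kernel(x, BLOCK):
    y = x * 2
    if BLOCK > 64:
        y = y + 1
    return y
"""
print(compile_kernel(SRC, BLOCK=128))  # ['%0 = mul %x, 2', '%1 = add %0, 1', 'return %1']
print(compile_kernel(SRC, BLOCK=32))   # ['%0 = mul %x, 2', 'return %0']
```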

The tracing analogue is comparable in size.
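A self-contained stand-in for the tracing side follows (Graph, Tracer, cond and trace are names chosen for illustration, not any library's API). Positional parameters become tracers and keyword arguments stay ordinary Python values; only the overloaded operators and an explicit cond combinator record IR, while the trace-time if and the list comprehension need no support at all. A Python if on a tracer raises, which is the behaviour the text describes.

```python
import itertools

class Graph:
    """Collects the IR produced while a plain Python function runs on tracers."""
    def __init__(self):
        self.ir, self.ids = [], itertools.count()

    def fresh(self):
        return f"%{next(self.ids)}"

    def emit(self, op, *args):
        name = self.fresh()
        self.ir.append(f"{name} = {op} " + ", ".join(ref(a) for a in args))
        return Tracer(name, self)

class Tracer:
    """A symbolic value: operators record IR instead of computing anything."""
    def __init__(self, name, graph):
        self.name, self.graph = name, graph

    def __add__(self, other):
        return self.graph.emit("add", self, other)

    def __mul__(self, other):
        return self.graph.emit("mul", self, other)

    def __gt__(self, other):
        return self.graph.emit("cmpgt", self, other)

    def __bool__(self):
        # A Python `if` on a tracer has no concrete answer at trace time.
        raise TypeError("cannot branch on a traced value; use cond()")

def ref(value):
    # Tracers print as SSA names; Python constants are inlined as literals.
    return value.name if isinstance(value, Tracer) else repr(value)

def cond(pred, true_fn, false_fn, operand):
    # Explicit runtime branch: trace both callables into a single region.
    g = pred.graph
    out = g.fresh()
    g.ir.append(f"{out} = scf.if {pred.name} {{")
    g.ir.append(f"yield {ref(true_fn(operand))}")
    g.ir.append("} else {")
    g.ir.append(f"yield {ref(false_fn(operand))}")
    g.ir.append("}")
    return Tracer(out, g)

def trace(fn, *params, **constants):
    # Positional params become tracers; keyword arguments stay plain Python values.
    g = Graph()
    result = fn(*[Tracer(f"%{p}", g) for p in params], **constants)
    g.ir.append(f"return {ref(result)}")
    return g.ir

def kernel(x, *, block, relu):
    y = x * block  # block is a Python int: inlined as a constant
    if relu:       # relu is a Python bool: an ordinary trace-time branch
        y = cond(y > 0, lambda v: v, lambda v: v * 0, y)
    outs = [y + i for i in range(2)]  # comprehension: no visitor required
    return outs[-1]

print(trace(kernel, "x", block=4, relu=False))
```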

The two implementations differ by only a handful of lines, yet every distinction between them becomes a larger distinction between Triton and Pallas. The AST-based version requires a visitor for every Python construct the user might invoke, whereas the tracing version inherits every Python construct for free provided it terminates in operator dispatch. In the AST version, `if x:` demands special-casing inside visit_If for the constexpr and runtime cases; in the tracing version the same line either succeeds, because x is a concrete Python boolean, or raises, because x is a tracer and the author must instead reach for the explicit cond(pred, ...) combinator. Both behaviours are useful; only one needs a 1700-line visitor to implement.

Conclusion #

Tracing pays the cost of awkward runtime syntax. In exchange it gets a much simpler and more robust compiler implementation, and the ability to express complex generic algorithms through Python metaprogramming, which is often what a reusable set of libraries requires.

Source for the snippets above: Triton commit at triton-lang/triton, CUTLASS at NVIDIA/cutlass, JAX at jax-ml/jax. The two CuTe bugs are #3266 and #3268. The compile timing quoted above was measured locally in nvidia/cuda:12.5.0-devel-ubuntu22.04; the figure will vary with toolkit version and host configuration. Corrections and counter-examples are welcome by email.
