PyTorch Triton Kernel Transparent Tracing and Compilation

PyTorch has introduced transparent tracing and compilation for Triton kernels, which lets custom operations stay visible to the compiler for optimization. Triton kernels can be compiled through both the `torch.compile` and `torch.export` workflows. Non-strict `torch.export` additionally requires a specific registration pattern: the wrapper is registered with `@triton_op` and launches the kernel through `wrap_triton`. With this, Triton kernels can be optimized by the PyTorch compiler and used in C++/CUDA environments through AOTInductor. This overcomes the previous limitation that custom operations were treated as opaque.

Introduction

PyTorch allows the user to create custom operations in two ways:

- `torch.library.custom_op` (https://docs.pytorch.org/docs/2.12/library.html).
- C++/CUDA custom functions and classes (/blog/PyTorch-Custom-Operation/).

These custom operations are treated as opaque operators during tracing and compilation. Their internals are not visible to the PyTorch compiler, so the compiler cannot optimize them. Custom operations are usually handled the same way in other deep learning inference frameworks, such as TensorRT.

PyTorch also allows the user to create Triton kernel functions decorated with `@triton.jit` and Just-In-Time (JIT) compile them. These kernels can be used in models for training and inference, both in eager execution and in the `torch.compile` and `torch.export` compilation workflows. A Triton kernel function can, of course, be treated as an opaque custom operation with `@torch.library.register_fake` so that FakeTensor-based symbolic tracing works. This has two disadvantages. The Triton kernel cannot be optimized by the compiler, and Triton JIT compilation is only available in the Python environment.
The user may want the compiler to have a chance to optimize the Triton kernel, or may want to use pre-compiled Triton kernels in a C++/CUDA environment. In either case, the Triton kernel implementation must be visible to the compiler. In this blog post, I will discuss how to make Triton kernels visible to tracing and compilation by `torch.compile`, `torch.export`, and AOTInductor.

PyTorch Triton Kernel Transparent Tracing and Compilation

In the following example, I created a simple SiLU Triton kernel, `triton_silu_kernel`, and wrapped it in two different Python functions.

- The first function, `triton_silu_triton_op`, is registered as a custom operation with `@triton_op` and launches the kernel through `wrap_triton`. It is therefore treated as a transparent Triton operator whose kernel stays visible during tracing and compilation.
- The second function, `triton_silu_pytorch_op`, is not registered with `@triton_op` and does not use `wrap_triton`. It launches the kernel directly.

It turns out that both `triton_silu_triton_op` and `triton_silu_pytorch_op` can be traced and compiled by `torch.compile(fullgraph=True)`. However, `torch.export` can only export the Triton kernel `triton_silu_kernel` in the following three cases:

- `triton_silu_triton_op`, registered with `@triton_op` and wrapped with `wrap_triton`, exported with `strict=False`.
- `triton_silu_triton_op`, registered with `@triton_op` and wrapped with `wrap_triton`, exported with `strict=True`.
- `triton_silu_pytorch_op`, not registered with `@triton_op` and without `wrap_triton`, exported with `strict=True`.

The last case works because `torch.export` with `strict=True` uses TorchDynamo-based tracing. TorchDynamo can see through the Python function and trace into the Triton kernel, even though the kernel is not registered as a transparent custom operation.
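The export outcomes listed above can be encoded in a small helper that also builds the command line for each combination. This is a hypothetical sketch.

- `export_expected_to_succeed` and `command_for` are illustrative names.
- The outcomes are exactly the ones reported in this post.
- The script name `torch_triton_export_aotinductor.py` and the flags `--use_registered_triton_op` and `--strict_export` are taken from the script used in this post.

The helper only builds the commands. Running them still requires a CUDA machine.

```python
import itertools

SCRIPT = "torch_triton_export_aotinductor.py"  # script name used in this post


def export_expected_to_succeed(use_triton_op: bool, strict: bool) -> bool:
    # Reported outcomes: export fails only for the unregistered op
    # (no @triton_op, no wrap_triton) combined with non-strict export.
    return use_triton_op or strict


def command_for(use_triton_op: bool, strict: bool) -> list[str]:
    # Build the argv list for one combination of the script's two flags.
    cmd = ["python", SCRIPT]
    if use_triton_op:
        cmd.append("--use_registered_triton_op")
    if strict:
        cmd.append("--strict_export")
    return cmd


# Enumerate all four combinations and print the expected outcome of each.
for use_triton_op, strict in itertools.product([True, False], repeat=2):
    status = "export OK" if export_expected_to_succeed(use_triton_op, strict) else "export fails"
    print(" ".join(command_for(use_triton_op, strict)), "->", status)
```

To execute the runs, each list returned by `command_for` can be passed to `subprocess.run`.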
Triton Kernel, `torch.compile`, `torch.export`, and AOTInductor

```python
import argparse
import copy
import os
import random
import shutil

import torch
import torch.profiler
import triton
import triton.language as tl
from torch.export import export, Dim
from torch.library import triton_op, wrap_triton


# 1. Define the Pure Triton Kernel
@triton.jit
def triton_silu_kernel(in_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    x = tl.load(in_ptr + offsets, mask=mask)
    # Keep everything in float32 for the math
    x_f32 = x.to(tl.float32)
    sigmoid_x = 1.0 / (1.0 + tl.exp(-x_f32))
    # Multiply in float32, THEN cast back to the input dtype (bfloat16 here)
    out = (x_f32 * sigmoid_x).to(x.dtype)

    tl.store(out_ptr + offsets, out, mask=mask)


# 2. Register with triton_op to ensure transparent tracing by torch.export
@triton_op("custom_ops::triton_silu_triton_op", mutates_args={})
def triton_silu_triton_op(x: torch.Tensor) -> torch.Tensor:
    # Enforce contiguous memory to ensure 1D pointer arithmetic is safe
    assert x.is_contiguous(), "Input tensor must be contiguous"
    out = torch.empty_like(x)
    n_elements = x.numel()

    def grid(meta):
        return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

    wrap_triton(triton_silu_kernel)[grid](x, out, n_elements, BLOCK_SIZE=1024)
    return out


# Alternative PyTorch operator that uses the same Triton kernel but is not registered with triton_op.
def triton_silu_pytorch_op(x: torch.Tensor) -> torch.Tensor:
    # Enforce contiguous memory to ensure 1D pointer arithmetic is safe
    assert x.is_contiguous(), "Input tensor must be contiguous"
    out = torch.empty_like(x)
    n_elements = x.numel()

    def grid(meta):
        return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

    triton_silu_kernel[grid](x, out, n_elements, BLOCK_SIZE=1024)
    return out


def pytorch_silu_eager(x: torch.Tensor) -> torch.Tensor:
    return torch.nn.functional.silu(x)


# 3. Model Architecture
class CustomModel(torch.nn.Module):
    def __init__(self, in_features=128, out_features=64, silu_op="triton_silu_triton_op"):
        super().__init__()
        self.linear1 = torch.nn.Linear(in_features, out_features)
        self.linear2 = torch.nn.Linear(out_features, in_features)
        self.silu_op = silu_op

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.linear1(x)
        # Call the selected SiLU implementation
        if self.silu_op == "triton_silu_triton_op":
            x = torch.ops.custom_ops.triton_silu_triton_op.default(x)
        elif self.silu_op == "triton_silu_pytorch_op":
            x = triton_silu_pytorch_op(x)
        elif self.silu_op == "pytorch_silu_eager":
            x = pytorch_silu_eager(x)
        else:
            raise ValueError(f"Invalid silu_op: {self.silu_op}")
        x = self.linear2(x)
        return x


def main():
    # Set random seed for reproducibility
    random_seed = 42
    torch.manual_seed(random_seed)
    torch.cuda.manual_seed_all(random_seed)
    random.seed(random_seed)
    # For deterministic behavior (optional, may affect performance)
    torch.use_deterministic_algorithms(True, warn_only=True)

    # Remove TorchInductor and Triton kernel cache if present
    inductor_cache_dir = "/tmp/torchinductor_root/"
    if os.path.exists(inductor_cache_dir):
        print(f"Removing TorchInductor cache directory: {inductor_cache_dir}")
        shutil.rmtree(inductor_cache_dir)

    parser = argparse.ArgumentParser(description="Torch Triton Export with AOTInductor")
    parser.add_argument(
        "--use_registered_triton_op",
        action="store_true",
        default=False,
        help="Use registered custom Triton op (default: False)",
    )
    parser.add_argument(
        "--strict_export",
        action="store_true",
        default=False,
        help="Enable strict export (TorchDynamo tracing, default: False)",
    )
    # torch.compile always uses fullgraph=True, so there is no separate fullgraph flag
    args = parser.parse_args()
    use_registered_triton_op = args.use_registered_triton_op
    strict_export = args.strict_export

    silu_op = "triton_silu_triton_op" if use_registered_triton_op else "triton_silu_pytorch_op"
    print(f"Using silu_op: {silu_op}")

    in_features = 128
    out_features = 64
    device = "cuda"
    dtype = torch.bfloat16
    atol = 1e-5
    rtol = 1e-5

    # 4. Instantiate model and weights in BF16
    model = CustomModel(
        in_features=in_features, out_features=out_features, silu_op=silu_op
    ).to(device=device, dtype=dtype)
    # Randomly initialize all parameters
    for param in model.parameters():
        if param.requires_grad:
            torch.nn.init.uniform_(param, -1.0, 1.0)
    model.eval()

    # 5. Define dynamic batch dimension constraints
    # batch_size can vary from 1 to 1024
    batch_dim = Dim("batch", min=1, max=1024)
    dynamic_shapes = {"x": {0: batch_dim}}

    # 6. Prepare sample input
    sample_input = torch.randn(8, in_features, device=device, dtype=dtype)

    # 7. Always TorchDynamo fullgraph compile
    print("--- Cleaning torch.compile cache ---")
    torch._dynamo.reset()  # Clean torch.compile cache
    print("--- Compiling model with torch.compile(fullgraph=True) ---")
    torch_compiled_model = torch.compile(model, fullgraph=True)

    # 8. Trace the model using torch.export (AOTInductor path)
    print(f"--- Exporting model via torch.export(strict={strict_export}) ---")
    exported_program = export(
        model,
        args=(sample_input,),
        dynamic_shapes=dynamic_shapes,
        strict=strict_export,
    )
    print("Model successfully traced and exported!")
    print("\nGraph Nodes Extracted:")
    exported_program.graph.print_tabular()

    # 9. Compile and Package via AOTInductor
    print("\n--- Compiling and Packaging via AOTInductor ---")
    output_package = "/tmp/compiled_model.pt2"
    # Clean up previous artifacts if they exist
    if os.path.exists(output_package):
        os.remove(output_package)

    # Instruct Inductor to accept user-defined Triton kernels natively
    torch._inductor.config.static_launch_user_defined_triton_kernels = True

    # Use the unified package compiler. This wraps weights, metadata, and the
    # compiled binary artifact natively inside a single zipped .pt2 container file.
    package_path = torch._inductor.aoti_compile_and_package(
        exported_program,
        package_path=output_package,
    )
    print(f"Compilation finished! Self-contained package saved to: {package_path}")

    # 10. Correctness Verification
    # Prepare inference input for correctness and profiling
    inference_input = torch.randn(16, in_features, device=device, dtype=dtype)

    # Run both models to get outputs for correctness check
    with torch.no_grad():
        torch_compiled_output = torch_compiled_model(inference_input)

    compiled_runner = torch._inductor.aoti_load_package(package_path)
    with torch.no_grad():
        aotinductor_compiled_output = compiled_runner(inference_input)

    # Reference output: copy model and set silu_op to pytorch_silu_eager
    reference_model = copy.deepcopy(model)
    reference_model.silu_op = "pytorch_silu_eager"
    reference_model.eval()
    with torch.no_grad():
        eager_output = reference_model(inference_input)

    is_torch_compile_correct = torch.allclose(torch_compiled_output, eager_output, atol=atol, rtol=rtol)
    is_aotinductor_correct = torch.allclose(aotinductor_compiled_output, eager_output, atol=atol, rtol=rtol)
    is_outputs_match = torch.allclose(
        torch_compiled_output, aotinductor_compiled_output, atol=atol, rtol=rtol
    )

    print(f"torch compiled model output shape: {torch_compiled_output.shape}")
    print(f"aotinductor compiled model output shape: {aotinductor_compiled_output.shape}")
    print(f"eager output shape: {eager_output.shape}")
    print(f"torch.compile correctness vs eager? -> {is_torch_compile_correct}")
    print(f"aotinductor correctness vs eager? -> {is_aotinductor_correct}")
    print(f"torch.compile vs aotinductor match? -> {is_outputs_match}")

    # --- Run profiling at the end of the program ---
    # Define file paths for profiling traces
    torch_compile_profiler_path = "./torch_compile_profiler.json"
    aotinductor_profiler_path = "./aotinductor_profiler.json"

    print("\n--- Running torch compiled model (fullgraph=True) with profiler ---")
    activities = [
        torch.profiler.ProfilerActivity.CPU,
        torch.profiler.ProfilerActivity.CUDA,
    ]
    warmup = 3
    steps = 5
    schedule = torch.profiler.schedule(wait=0, warmup=warmup, active=steps, repeat=1)
    with torch.profiler.profile(
        activities=activities,
        schedule=schedule,
        record_shapes=True,
        with_flops=True,
    ) as prof:
        for step in range(warmup + steps):
            with torch.profiler.record_function(f"step_{step}"):
                with torch.no_grad():
                    torch_compiled_output = torch_compiled_model(inference_input)
            prof.step()
    prof.export_chrome_trace(torch_compile_profiler_path)
    print(f"Profiling trace for torch compiled model saved to {torch_compile_profiler_path}")

    print("\n--- Loading AOTInductor Compiled Model Package & Running Inference with profiler ---")
    compiled_runner = torch._inductor.aoti_load_package(package_path)
    with torch.profiler.profile(
        activities=activities,
        schedule=schedule,
        record_shapes=True,
        with_flops=True,
    ) as prof:
        for step in range(warmup + steps):
            with torch.profiler.record_function(f"step_{step}"):
                with torch.no_grad():
                    aotinductor_compiled_output = compiled_runner(inference_input)
            prof.step()
    prof.export_chrome_trace(aotinductor_profiler_path)
    print(f"Profiling trace for aotinductor compiled output saved to {aotinductor_profiler_path}")


if __name__ == "__main__":
    main()
```

`triton_silu_pytorch_op` and `strict=False`

In this case, `torch.export` fails with an error that tells us to wrap the Triton kernel as an opaque operator. This is very confusing, and it defeats the purpose of making the kernel transparent.
$ python torch_triton_export_aotinductor.py
Removing TorchInductor cache directory: /tmp/torchinductor_root/
Using silu_op: triton_silu_pytorch_op
--- Cleaning torch.compile cache ---
--- Compiling model with torch.compile(fullgraph=True) ---
--- Exporting model via torch.export(strict=False) ---
Traceback (most recent call last):
  File "/mnt/torch_triton_export_aotinductor.py", line 249, in