[ARTICLE · art-23623] src=leimao.github.io pub= topic=machine-learning verified=true sentiment=· neutral

PyTorch Triton Kernel Transparent Tracing and Compilation

PyTorch has introduced transparent tracing and compilation for Triton kernels, allowing the kernels inside custom operations to be visible to the compiler for optimization. Triton kernels can be compiled through the `torch.compile` and `torch.export` workflows. Non-strict `torch.export` requires registering the operation with `@triton_op` and launching the kernel through `wrap_triton`, while strict export can also trace an unregistered kernel launch. This enables Triton kernels to be optimized by the PyTorch compiler and used in C++/CUDA environments through AOTInductor, overcoming the previous limitation of custom operations being treated as opaque.

13 min read · Published May 22, 2026

Introduction

PyTorch allows the user to create custom operations using `torch.library.custom_op` or C++/CUDA custom functions and classes. These custom operations are treated as opaque operators during tracing and compilation, which means their internals are not visible to the PyTorch compiler and cannot be optimized by it. This is also what usually happens to custom operations in other deep learning inference frameworks, such as TensorRT.

PyTorch also allows the user to write Triton kernel functions decorated with `@triton.jit`, Just-In-Time (JIT) compile them, and use them in models for training and inference, not only in eager execution but also in the `torch.compile` and `torch.export` compilation workflows. A Triton kernel can, of course, be wrapped in an opaque custom operation with a fake implementation registered via `@torch.library.register_fake`, so that FakeTensor-based symbolic tracing works. The disadvantage is that the Triton kernel then cannot be optimized by the compiler, and Triton JIT compilation is only available in the Python environment.

If the user wants the Triton kernel to have an opportunity to be optimized by the compiler, or wants to use pre-compiled Triton kernels in a C++/CUDA environment, the Triton kernel implementation must be visible to the compiler. In this blog post, I will discuss how to make Triton kernels visible to tracing and compilation by `torch.compile`, `torch.export`, and AOTInductor.

PyTorch Triton Kernel Transparent Tracing and Compilation

In the following example, I created a simple SiLU Triton kernel, `triton_silu_kernel`, and wrapped it in two different Python functions. The first function, `triton_silu_triton_op`, is registered as a custom operation with `@triton_op` and launches the Triton kernel through `wrap_triton`, which makes the operator transparent, rather than opaque, to tracing and compilation. The second function, `triton_silu_pytorch_op`, is not registered with `@triton_op` and launches the kernel directly, without `wrap_triton`.

It turns out that both `triton_silu_triton_op` and `triton_silu_pytorch_op` can be traced and compiled by `torch.compile(fullgraph=True)`. However, for `torch.export`, the Triton kernel `triton_silu_kernel` can only be exported in the following three cases:

1. `triton_silu_triton_op` (registered with `@triton_op` and launched through `wrap_triton`) and `torch.export` with `strict=False`.
2. `triton_silu_triton_op` (registered with `@triton_op` and launched through `wrap_triton`) and `torch.export` with `strict=True`.
3. `triton_silu_pytorch_op` (not registered with `@triton_op` and no `wrap_triton`) and `torch.export` with `strict=True`.

The remaining combination, `triton_silu_pytorch_op` with `strict=False`, fails to export.

In the case of `triton_silu_pytorch_op` (not registered with `@triton_op` and no `wrap_triton`) and `torch.export` with `strict=True`, export works because `torch.export` with `strict=True` uses TorchDynamo-based tracing, which can see through the Python function and trace into the Triton kernel, even though the Triton kernel is not registered as a transparent custom operation.
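The export results for all four combinations can be reproduced with the two command-line flags of the script below, `--use_registered_triton_op` and `--strict_export`. The following driver is a hypothetical sketch, assuming the script is saved as `torch_triton_export_aotinductor.py` in the working directory (the name used in the logs) and that a CUDA GPU with Triton is available for real runs. By default it only prints each command next to the outcome reported in this post.

```python
import itertools
import subprocess
import sys

# Assumed filename of the example script (taken from the logs in this post).
SCRIPT = "torch_triton_export_aotinductor.py"

# torch.export outcomes reported in this post,
# keyed by (use_registered_triton_op, strict_export).
REPORTED_EXPORT_OUTCOME = {
    (False, False): "fails: cannot access data pointer of a FakeTensor",
    (False, True): "exports: TorchDynamo traces the raw Triton launch",
    (True, False): "exports: @triton_op + wrap_triton",
    (True, True): "exports: @triton_op + wrap_triton",
}


def build_commands(script: str = SCRIPT) -> list[tuple[tuple[bool, bool], list[str]]]:
    """Return one (flag combination, command line) pair per combination."""
    commands = []
    for use_op, strict in itertools.product([False, True], repeat=2):
        cmd = [sys.executable, script]
        if use_op:
            cmd.append("--use_registered_triton_op")
        if strict:
            cmd.append("--strict_export")
        commands.append(((use_op, strict), cmd))
    return commands


def run_matrix(dry_run: bool = True) -> None:
    for key, cmd in build_commands():
        print(" ".join(cmd[1:]), "->", REPORTED_EXPORT_OUTCOME[key])
        if not dry_run:
            # Real runs need a CUDA GPU and Triton; a failed export gives a nonzero exit code.
            result = subprocess.run(cmd, check=False)
            print("  exit code:", result.returncode)


if __name__ == "__main__":
    run_matrix(dry_run=True)
```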

Triton Kernel, `torch.compile`, `torch.export`, and AOTInductor

```python
import argparse
import copy
import os
import random
import shutil
import torch
import torch.profiler
import triton
import triton.language as tl
from torch.export import export, Dim
from torch.library import triton_op, wrap_triton


# 1. Define the Pure Triton Kernel
@triton.jit
def triton_silu_kernel(in_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(in_ptr + offsets, mask=mask)
    # Keep everything in float32 for the math
    x_f32 = x.to(tl.float32)
    sigmoid_x = 1.0 / (1.0 + tl.exp(-x_f32))
    # Multiply in float32, THEN cast back to the input dtype (bfloat16 here)
    out = (x_f32 * sigmoid_x).to(x.dtype)
    tl.store(out_ptr + offsets, out, mask=mask)


# 2. Register with triton_op to ensure transparent tracing by torch.export
@triton_op("custom_ops::triton_silu_triton_op", mutates_args={})
def triton_silu_triton_op(x: torch.Tensor) -> torch.Tensor:
    # Enforce contiguous memory to ensure 1D pointer arithmetic is safe
    assert x.is_contiguous(), "Input tensor must be contiguous"
    out = torch.empty_like(x)
    n_elements = x.numel()

    def grid(meta):
        return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]), )

    wrap_triton(triton_silu_kernel)[grid](x, out, n_elements, BLOCK_SIZE=1024)
    return out


# Alternative PyTorch operator that uses the same Triton kernel but is not registered with triton_op.
def triton_silu_pytorch_op(x: torch.Tensor) -> torch.Tensor:
    # Enforce contiguous memory to ensure 1D pointer arithmetic is safe
    assert x.is_contiguous(), "Input tensor must be contiguous"
    out = torch.empty_like(x)
    n_elements = x.numel()

    def grid(meta):
        return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]), )

    triton_silu_kernel[grid](x, out, n_elements, BLOCK_SIZE=1024)
    return out


def pytorch_silu_eager(x: torch.Tensor) -> torch.Tensor:
    return torch.nn.functional.silu(x)


# 3. Model Architecture
class CustomModel(torch.nn.Module):

    def __init__(self,
                 in_features=128,
                 out_features=64,
                 silu_op="triton_silu_triton_op"):
        super().__init__()
        self.linear1 = torch.nn.Linear(in_features, out_features)
        self.linear2 = torch.nn.Linear(out_features, in_features)
        self.silu_op = silu_op

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.linear1(x)
        # Call the transparent operator
        if self.silu_op == "triton_silu_triton_op":
            x = torch.ops.custom_ops.triton_silu_triton_op.default(x)
        elif self.silu_op == "triton_silu_pytorch_op":
            x = triton_silu_pytorch_op(x)
        elif self.silu_op == "pytorch_silu_eager":
            x = pytorch_silu_eager(x)
        else:
            raise ValueError(f"Invalid silu_op: {self.silu_op}")
        x = self.linear2(x)
        return x


def main():
    # Set random seed for reproducibility
    random_seed = 42
    torch.manual_seed(random_seed)
    torch.cuda.manual_seed_all(random_seed)
    random.seed(random_seed)
    # For deterministic behavior (optional, may affect performance)
    torch.use_deterministic_algorithms(True, warn_only=True)

    # Remove TorchInductor and Triton kernel cache if present
    inductor_cache_dir = "/tmp/torchinductor_root/"
    if os.path.exists(inductor_cache_dir):
        print(f"Removing TorchInductor cache directory: {inductor_cache_dir}")
        shutil.rmtree(inductor_cache_dir)

    parser = argparse.ArgumentParser(
        description="Torch Triton Export with AOTInductor")
    parser.add_argument(
        "--use_registered_triton_op",
        action="store_true",
        default=False,
        help="Use registered custom Triton op (default: False)")
    parser.add_argument(
        "--strict_export",
        action="store_true",
        default=False,
        help="Enable strict export (TorchDynamo tracing, default: False)")
    # torch.compile always uses fullgraph=True, so there is no flag for it
    args = parser.parse_args()
    use_registered_triton_op = args.use_registered_triton_op
    strict_export = args.strict_export

    silu_op = "triton_silu_triton_op" if use_registered_triton_op else "triton_silu_pytorch_op"
    print(f"Using silu_op: {silu_op}")

    in_features = 128
    out_features = 64
    device = "cuda"
    dtype = torch.bfloat16
    atol = 1e-5
    rtol = 1e-5

    # 4. Instantiate model and weights in BF16
    model = CustomModel(in_features=in_features,
                        out_features=out_features,
                        silu_op=silu_op).to(device=device, dtype=dtype)
    # Randomly initialize all parameters
    for param in model.parameters():
        if param.requires_grad:
            torch.nn.init.uniform_(param, -1.0, 1.0)
    model.eval()

    # 5. Define dynamic batch dimension constraints (batch size can vary from 1 to 1024)
    batch_dim = Dim("batch", min=1, max=1024)
    dynamic_shapes = {"x": {0: batch_dim}}

    # 6. Prepare sample input
    sample_input = torch.randn(8, in_features, device=device, dtype=dtype)

    # 7. Always TorchDynamo fullgraph compile
    print("--- Cleaning torch.compile cache ---")
    torch._dynamo.reset()  # Clean torch.compile cache
    print("--- Compiling model with torch.compile(fullgraph=True) ---")
    torch_compiled_model = torch.compile(model, fullgraph=True)

    # 8. Trace the model using torch.export (AOTInductor path)
    print(f"--- Exporting model via torch.export (strict={strict_export}) ---")
    exported_program = export(model,
                              args=(sample_input, ),
                              dynamic_shapes=dynamic_shapes,
                              strict=strict_export)
    print("Model successfully traced and exported!")
    print("\nGraph Nodes Extracted:")
    exported_program.graph.print_tabular()

    # 9. Compile and Package via AOTInductor
    print("\n--- Compiling and Packaging via AOTInductor ---")
    output_package = "/tmp/compiled_model.pt2"
    # Clean up previous artifacts if they exist
    if os.path.exists(output_package):
        os.remove(output_package)
    # Instruct Inductor to accept user-defined Triton kernels natively
    torch._inductor.config.static_launch_user_defined_triton_kernels = True
    # Use the unified package compiler. This wraps weights, metadata,
    # and the compiled binary artifact natively inside a single zipped .pt2 container file.
    package_path = torch._inductor.aoti_compile_and_package(
        exported_program,
        package_path=output_package,
    )
    print(
        f"Compilation finished! Self-contained package saved to: {package_path}"
    )

    # 10. Correctness Verification
    # Prepare inference input for correctness and profiling
    inference_input = torch.randn(16, in_features, device=device, dtype=dtype)
    # Run both models to get outputs for correctness check
    with torch.no_grad():
        torch_compiled_output = torch_compiled_model(inference_input)
    compiled_runner = torch._inductor.aoti_load_package(package_path)
    with torch.no_grad():
        aotinductor_compiled_output = compiled_runner(inference_input)
    # Reference output: copy model and set silu_op to pytorch_silu_eager
    reference_model = copy.deepcopy(model)
    reference_model.silu_op = "pytorch_silu_eager"
    reference_model.eval()
    with torch.no_grad():
        eager_output = reference_model(inference_input)
    is_torch_compile_correct = torch.allclose(torch_compiled_output,
                                              eager_output,
                                              atol=atol,
                                              rtol=rtol)
    is_aotinductor_correct = torch.allclose(aotinductor_compiled_output,
                                            eager_output,
                                            atol=atol,
                                            rtol=rtol)
    is_outputs_match = torch.allclose(torch_compiled_output,
                                      aotinductor_compiled_output,
                                      atol=atol,
                                      rtol=rtol)
    print(f"torch_compiled_model output shape: {torch_compiled_output.shape}")
    print(
        f"aotinductor_compiled_model output shape: {aotinductor_compiled_output.shape}"
    )
    print(f"eager output shape:                  {eager_output.shape}")
    print(
        f"torch.compile correctness vs eager?  -> **{is_torch_compile_correct}**"
    )
    print(
        f"aotinductor correctness vs eager?    -> **{is_aotinductor_correct}**"
    )
    print(f"torch.compile vs aotinductor match?  -> **{is_outputs_match}**")

    # --- Run profiling at the end of the program ---
    # Define file paths for profiling traces
    torch_compile_profiler_path = "./torch_compile_profiler.json"
    aotinductor_profiler_path = "./aotinductor_profiler.json"

    print(
        "\n--- Running torch_compiled_model (fullgraph=True) with profiler ---"
    )
    activities = [
        torch.profiler.ProfilerActivity.CPU,
        torch.profiler.ProfilerActivity.CUDA,
    ]
    warmup = 3
    steps = 5
    schedule = torch.profiler.schedule(wait=0,
                                       warmup=warmup,
                                       active=steps,
                                       repeat=1)
    with torch.profiler.profile(
            activities=activities,
            schedule=schedule,
            record_shapes=True,
            with_flops=True,
    ) as prof:
        for step in range(warmup + steps):
            with torch.profiler.record_function(f"step_{step}"):
                with torch.no_grad():
                    torch_compiled_output = torch_compiled_model(
                        inference_input)
            prof.step()
    prof.export_chrome_trace(torch_compile_profiler_path)
    print(
        f"Profiling trace for torch_compiled_model saved to {torch_compile_profiler_path}"
    )

    print(
        "\n---  AOTInductor Compiled Model Package & Running Inference with profiler ---"
    )
    compiled_runner = torch._inductor.aoti_load_package(package_path)
    with torch.profiler.profile(
            activities=activities,
            schedule=schedule,
            record_shapes=True,
            with_flops=True,
    ) as prof:
        for step in range(warmup + steps):
            with torch.profiler.record_function(f"step_{step}"):
                with torch.no_grad():
                    aotinductor_compiled_output = compiled_runner(
                        inference_input)
            prof.step()
    prof.export_chrome_trace(aotinductor_profiler_path)
    print(
        f"Profiling trace for aotinductor_compiled_output saved to {aotinductor_profiler_path}"
    )


if __name__ == "__main__":
    main()
```

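`triton_silu_pytorch_op` and `strict=False`

In this case, `torch.export` fails with an error telling us to wrap the Triton kernel into an opaque custom operator, which is very confusing and defeats our purpose. The traceback shows why: non-strict export traces the Python code with FakeTensors (via `make_fx`), so the unregistered launch `triton_silu_kernel[grid](...)` actually runs, and Triton's argument binder tries to access the data pointer of a FakeTensor, which has no real storage.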

```
$ python torch_triton_export_aotinductor.py
Removing TorchInductor cache directory: /tmp/torchinductor_root/
Using silu_op: triton_silu_pytorch_op
--- Cleaning torch.compile cache ---
--- Compiling model with torch.compile(fullgraph=True) ---
--- Exporting model via torch.export (strict=False) ---
Traceback (most recent call last):
  File "/mnt/torch_triton_export_aotinductor.py", line 249, in <module>
    main()
  File "/mnt/torch_triton_export_aotinductor.py", line 169, in main
    exported_program = export(model,
                       ^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/export/__init__.py", line 205, in export
    raise e
  File "/usr/local/lib/python3.12/dist-packages/torch/export/__init__.py", line 171, in export
    return _export(
           ^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/export/_trace.py", line 1343, in wrapper
    raise e
  File "/usr/local/lib/python3.12/dist-packages/torch/export/_trace.py", line 1309, in wrapper
    ep = fn(*args, **kwargs)
         ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/export/exported_program.py", line 124, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_utils_internal.py", line 96, in wrapper_function
    return function(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/export/_trace.py", line 2508, in _export
    ep = _export_for_training(
         ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/export/_trace.py", line 1343, in wrapper
    raise e
  File "/usr/local/lib/python3.12/dist-packages/torch/export/_trace.py", line 1309, in wrapper
    ep = fn(*args, **kwargs)
         ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/export/exported_program.py", line 124, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/export/_trace.py", line 2296, in _export_for_training
    export_artifact = export_func(
                      ^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/export/_trace.py", line 2225, in _non_strict_export
    aten_export_artifact = _to_aten_func(
                           ^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/export/_trace.py", line 2002, in _export_to_aten_ir_make_fx
    gm, graph_signature = transform(_make_fx_helper)(
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/export/_trace.py", line 2132, in _aot_export_non_strict
    gm, sig = aot_export(stack, wrapped_mod, args, kwargs=kwargs, **flags)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/export/_trace.py", line 1910, in _make_fx_helper
    gm = make_fx(
         ^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/fx/experimental/proxy_tensor.py", line 2826, in wrapped
    return make_fx_tracer.trace(f, *args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/fx/experimental/proxy_tensor.py", line 2727, in trace
    return self._trace_inner(f, *args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/fx/experimental/proxy_tensor.py", line 2688, in _trace_inner
    t = dispatch_trace(
        ^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_compile.py", line 54, in inner
    return disable_fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py", line 1255, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1533, in dispatch_trace
    graph = tracer.trace(root, concrete_args)  # type: ignore[arg-type]
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/fx/experimental/proxy_tensor.py", line 2264, in trace
    res = super().trace(root, concrete_args)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py", line 1255, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/fx/_symbolic_trace.py", line 890, in trace
    (self.create_arg(fn(*args)),),
                     ^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1603, in wrapped
    out = f(*tensors)  # type:ignore[call-arg]
          ^^^^^^^^^^^
  File "<string>", line 1, in <lambda>
  File "/usr/local/lib/python3.12/dist-packages/torch/export/_trace.py", line 1794, in wrapped_fn
    return tuple(flat_fn(*args))
                 ^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_functorch/_aot_autograd/utils.py", line 204, in flat_fn
    tree_out = fn(*args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_functorch/_aot_autograd/graph_capture_wrappers.py", line 1507, in functional_call
    out = mod(*args[params_len:], **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/fx/_symbolic_trace.py", line 864, in module_call_wrapper
    return self.call_module(mod, forward, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/fx/experimental/proxy_tensor.py", line 2353, in call_module
    return Tracer.call_module(self, m, forward, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/fx/_symbolic_trace.py", line 572, in call_module
    ret_val = forward(*args, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/fx/_symbolic_trace.py", line 857, in forward
    return _orig_module_call(mod, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1778, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1789, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/export/_trace.py", line 2116, in forward
    tree_out = mod(*args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/fx/_symbolic_trace.py", line 864, in module_call_wrapper
    return self.call_module(mod, forward, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/fx/experimental/proxy_tensor.py", line 2353, in call_module
    return Tracer.call_module(self, m, forward, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/fx/_symbolic_trace.py", line 572, in call_module
    ret_val = forward(*args, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/fx/_symbolic_trace.py", line 857, in forward
    return _orig_module_call(mod, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1778, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1789, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/torch_triton_export_aotinductor.py", line 87, in forward
    x = triton_silu_pytorch_op(x)
        ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/torch_triton_export_aotinductor.py", line 61, in triton_silu_pytorch_op
    triton_silu_kernel[grid](x, out, n_elements, BLOCK_SIZE=1024)
  File "/usr/local/lib/python3.12/dist-packages/triton/runtime/jit.py", line 370, in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/triton/runtime/jit.py", line 723, in run
    bound_args, specialization, options = binder(*args, **kwargs)
                                          ^^^^^^^^^^^^^^^^^^^^^^^
  File "<string>", line 4, in dynamic_func
  File "/usr/local/lib/python3.12/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1654, in __torch_function__
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1725, in __torch_function__
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_export/non_strict_utils.py", line 1159, in __torch_function__
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Cannot access data pointer of Tensor (e.g. FakeTensor, FunctionalTensor). If you're using torch.compile/export/fx, it is likely that we are erroneously tracing into a custom kernel. To fix this, please wrap the custom kernel into an opaque custom op. Please see the following for details: https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html
```

`triton_silu_pytorch_op` and `strict=True`

Using `triton_silu_pytorch_op` with `torch.export` and `strict=True` works because strict export uses TorchDynamo-based tracing, which can see through the Python function and trace into the Triton kernel, even though the Triton kernel is not registered as a transparent custom operation. In the exported graph, the kernel launch is captured as a `triton_kernel_wrapper_mutation` node.

```
$ python torch_triton_export_aotinductor.py --strict_export
Removing TorchInductor cache directory: /tmp/torchinductor_root/
Using silu_op: triton_silu_pytorch_op
--- Cleaning torch.compile cache ---
--- Compiling model with torch.compile(fullgraph=True) ---
--- Exporting model via torch.export (strict=True) ---
/usr/local/lib/python3.12/dist-packages/torch/utils/_config_module.py:540: FutureWarning: torch._dynamo.config.skip_code_recursive_on_recompile_limit_hit is deprecated and does not do anything. It will be removed in a future version of PyTorch.
  config[key] = copy.deepcopy(getattr(self, key))
Model successfully traced and exported!

Graph Nodes Extracted:
opcode         name                                  target                          args                                            kwargs
-------------  ------------------------------------  ------------------------------  ----------------------------------------------  ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------
placeholder    p_linear1_weight                      p_linear1_weight                ()                                              {}
placeholder    p_linear1_bias                        p_linear1_bias                  ()                                              {}
placeholder    p_linear2_weight                      p_linear2_weight                ()                                              {}
placeholder    p_linear2_bias                        p_linear2_bias                  ()                                              {}
placeholder    x                                     x                               ()                                              {}
call_function  sym_size_int                          aten.sym_size.int               (x, 0)                                          {}
call_function  linear                                aten.linear.default             (x, p_linear1_weight, p_linear1_bias)           {}
call_function  empty_like                            aten.empty_like.default         (linear,)                                       {'pin_memory': False}
call_function  mul                                   <built-in function mul>         (64, sym_size_int)                              {}
call_function  add                                   <built-in function add>         (mul, 1024)                                     {}
call_function  sub                                   <built-in function sub>         (add, 1)                                        {}
call_function  floordiv                              <built-in function floordiv>    (sub, 1024)                                     {}
call_function  triton_kernel_wrapper_mutation_proxy  triton_kernel_wrapper_mutation  ()                                              {'kernel_idx': 0, 'constant_args_idx': 0, 'grid': [(floordiv, 1, 1)], 'tma_descriptor_metadata': {}, 'kwargs': {'in_ptr': linear, 'out_ptr': empty_like, 'n_elements': mul}}
call_function  linear_1                              aten.linear.default             (empty_like, p_linear2_weight, p_linear2_bias)  {}
output         output                                output                          ((linear_1,),)                                  {}

--- Compiling and Packaging via AOTInductor ---
/usr/lib/python3.12/copyreg.py:99: FutureWarning: `isinstance(treespec, LeafSpec)` is deprecated, use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead.
  return cls.__new__(cls, *args)
Compilation finished! Self-contained package saved to: /tmp/compiled_model.pt2

--- Running torch_compiled_model (fullgraph=True) ---

---  AOTInductor Compiled Model Package & Running Inference ---
torch_compiled_model output shape: torch.Size([16, 128])
aotinductor_compiled_model output shape: torch.Size([16, 128])
eager output shape:                  torch.Size([16, 128])
torch.compile correctness vs eager?  -> **True**
aotinductor correctness vs eager?    -> **True**
torch.compile vs aotinductor match?  -> **True**
```
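In this strict-export graph, the launch grid `triton.cdiv(n_elements, BLOCK_SIZE)` has been traced symbolically into the `mul`, `add`, `sub`, and `floordiv` nodes that feed the `grid` argument of the `triton_kernel_wrapper_mutation` node. Here `n_elements` is `64 * batch`, because `linear1` produces 64 features per row, and ceiling division is expressed as `(n + BLOCK_SIZE - 1) // BLOCK_SIZE`. The pure-Python sketch below evaluates the same expression for a few batch sizes within the exported range of 1 to 1024; `silu_grid` is a hypothetical helper name.

```python
OUT_FEATURES = 64   # out_features of linear1 in the example model
BLOCK_SIZE = 1024   # BLOCK_SIZE passed to triton_silu_kernel


def cdiv(a: int, b: int) -> int:
    """Ceiling division, the same formula as the traced add/sub/floordiv nodes."""
    return (a + b - 1) // b


def silu_grid(batch: int) -> tuple[int, int, int]:
    """Launch grid of triton_silu_kernel as a function of the dynamic batch size."""
    n_elements = OUT_FEATURES * batch             # node: mul(64, sym_size_int)
    return (cdiv(n_elements, BLOCK_SIZE), 1, 1)   # node: floordiv(sub(add(mul, 1024), 1), 1024)


for batch in (1, 8, 16, 17, 1024):
    print(batch, silu_grid(batch))
```

For example, a batch of 16 rows has exactly 1024 elements and needs one program, while 17 rows need two.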

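`triton_silu_triton_op` and `strict=False`

When `triton_silu_triton_op` is registered with `@triton_op` and launches the kernel through `wrap_triton`, it can be exported by `torch.export` with either `strict=True` or `strict=False`. The exported graph contains a single `custom_ops.triton_silu_triton_op.default` node in place of the SiLU.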

12345678910111213141516171819202122232425262728293031323334353637

|

$ python torch_triton_export_aotinductor.py --use_registered_triton_op
Removing TorchInductor cache directory: /tmp/torchinductor_root/
Using silu_op: triton_silu_triton_op
--- Cleaning torch.compile cache ---
--- Compiling model with torch.compile(fullgraph=True) ---
--- Exporting model via torch.export (strict=False) ---
Model successfully traced and exported!
Graph Nodes Extracted:
opcode         name                   target                                    args                                                       kwargs
-------------  ---------------------  ----------------------------------------  ---------------------------------------------------------  --------
placeholder    p_linear1_weight       p_linear1_weight                          ()                                                         {}
placeholder    p_linear1_bias         p_linear1_bias                            ()                                                         {}
placeholder    p_linear2_weight       p_linear2_weight                          ()                                                         {}
placeholder    p_linear2_bias         p_linear2_bias                            ()                                                         {}
placeholder    x                      x                                         ()                                                         {}
call_function  linear                 aten.linear.default                       (x, p_linear1_weight, p_linear1_bias)                      {}
call_function  triton_silu_triton_op  custom_ops.triton_silu_triton_op.default  (linear,)                                                  {}
call_function  linear_1               aten.linear.default                       (triton_silu_triton_op, p_linear2_weight, p_linear2_bias)  {}
output         output                 output                                    ((linear_1,),)                                             {}
--- Compiling and Packaging via AOTInductor ---
/usr/lib/python3.12/copyreg.py:99: FutureWarning: `isinstance(treespec, LeafSpec)` is deprecated, use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead.
  return cls.__new__(cls, *args)
/usr/local/lib/python3.12/dist-packages/torch/utils/_config_module.py:540: FutureWarning: torch._dynamo.config.skip_code_recursive_on_recompile_limit_hit is deprecated and does not do anything. It will be removed in a future version of PyTorch.
  config[key] = copy.deepcopy(getattr(self, key))
Compilation finished! Self-contained package saved to: /tmp/compiled_model.pt2
--- Running torch_compiled_model (fullgraph=True) ---
---  AOTInductor Compiled Model Package & Running Inference ---
torch_compiled_model output shape: torch.Size([16, 128])
aotinductor_compiled_model output shape: torch.Size([16, 128])
eager output shape:                  torch.Size([16, 128])
torch.compile correctness vs eager?  -> True
aotinductor correctness vs eager?    -> True
torch.compile vs aotinductor match?  -> True


triton_silu_triton_op and strict=True

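When triton_silu_triton_op is registered with @triton_op and its Triton kernel launch is wrapped with wrap_triton, the model can be exported by torch.export regardless of whether strict=True or strict=False is used. With strict=True, the exported graph is identical to the one from the strict=False run, and the Triton op again appears as a custom_ops.triton_silu_triton_op.default node.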


$ python torch_triton_export_aotinductor.py --use_registered_triton_op --strict_export
Removing TorchInductor cache directory: /tmp/torchinductor_root/
Using silu_op: triton_silu_triton_op
--- Cleaning torch.compile cache ---
--- Compiling model with torch.compile(fullgraph=True) ---
--- Exporting model via torch.export (strict=True) ---
/usr/local/lib/python3.12/dist-packages/torch/utils/_config_module.py:540: FutureWarning: torch._dynamo.config.skip_code_recursive_on_recompile_limit_hit is deprecated and does not do anything. It will be removed in a future version of PyTorch.
  config[key] = copy.deepcopy(getattr(self, key))
Model successfully traced and exported!
Graph Nodes Extracted:
opcode         name                   target                                    args                                                       kwargs
-------------  ---------------------  ----------------------------------------  ---------------------------------------------------------  --------
placeholder    p_linear1_weight       p_linear1_weight                          ()                                                         {}
placeholder    p_linear1_bias         p_linear1_bias                            ()                                                         {}
placeholder    p_linear2_weight       p_linear2_weight                          ()                                                         {}
placeholder    p_linear2_bias         p_linear2_bias                            ()                                                         {}
placeholder    x                      x                                         ()                                                         {}
call_function  linear                 aten.linear.default                       (x, p_linear1_weight, p_linear1_bias)                      {}
call_function  triton_silu_triton_op  custom_ops.triton_silu_triton_op.default  (linear,)                                                  {}
call_function  linear_1               aten.linear.default                       (triton_silu_triton_op, p_linear2_weight, p_linear2_bias)  {}
output         output                 output                                    ((linear_1,),)                                             {}
--- Compiling and Packaging via AOTInductor ---
/usr/lib/python3.12/copyreg.py:99: FutureWarning: `isinstance(treespec, LeafSpec)` is deprecated, use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead.
  return cls.__new__(cls, *args)
Compilation finished! Self-contained package saved to: /tmp/compiled_model.pt2
--- Running torch_compiled_model (fullgraph=True) ---
---  AOTInductor Compiled Model Package & Running Inference ---
torch_compiled_model output shape: torch.Size([16, 128])
aotinductor_compiled_model output shape: torch.Size([16, 128])
eager output shape:                  torch.Size([16, 128])
torch.compile correctness vs eager?  -> True
aotinductor correctness vs eager?    -> True
torch.compile vs aotinductor match?  -> True


Conclusions

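Triton kernels can be used in PyTorch not only through JIT compilation in the Python environment but also through ahead-of-time compilation with AOTInductor, which produces a self-contained package usable from C++/CUDA environments. For torch.export to succeed, the op should be registered with @triton_op and its kernel launch wrapped with wrap_triton.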

References

PyTorch Triton Kernel Transparent Tracing and Compilation

https://leimao.github.io/blog/PyTorch-Triton-Kernel-Transparent-Tracing-and-Compilation/
