Introduction
PyTorch allows the user to create custom operations using torch.library.custom_op or C++/CUDA custom functions and classes. These custom operations are treated as opaque operators during tracing and compilation, which means that their internals are not visible to the PyTorch compiler and cannot be optimized by it. This is also how custom operations are usually handled in other deep learning inference frameworks, such as TensorRT.
PyTorch also allows the user to create Triton kernel functions decorated with @triton.jit, Just-In-Time (JIT) compile them, and use them in models for training and inference, not only in eager execution but also in the torch.compile and torch.export compilation workflows. A Triton kernel can of course be wrapped as an opaque custom operation, with a fake implementation registered via @torch.library.register_fake so that FakeTensor-based symbolic tracing works. The disadvantages are that the compiler cannot optimize the Triton kernel and that Triton JIT compilation is only available in a Python environment.
If the user wants the compiler to have an opportunity to optimize the Triton kernel, or wants to use pre-compiled Triton kernels in a C++/CUDA environment, the Triton kernel implementation must be visible to the compiler. In this blog post, I will discuss how to make Triton kernels visible to tracing and compilation by torch.compile, torch.export, and AOTInductor.
PyTorch Triton Kernel Transparent Tracing and Compilation
In the following example, I created a simple SiLU Triton kernel, triton_silu_kernel, and wrapped it in two different Python functions. The first function, triton_silu_triton_op, is registered as a custom operation with @triton_op and launches the kernel through wrap_triton, which means it is treated as a transparent Triton operator during tracing and compilation. The second function, triton_silu_pytorch_op, is a plain Python function: it is not registered with @triton_op and does not use wrap_triton.
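For readers less familiar with Triton, the blocking and masking scheme that a 1D elementwise kernel like this one relies on can be emulated in plain Python. This is only a sketch of the indexing logic under my own naming (cdiv and silu_blocked are hypothetical helpers), not the Triton kernel itself: each outer loop iteration plays the role of one Triton program, and the bounds check plays the role of the load/store mask.

```python
import math


def cdiv(a: int, b: int) -> int:
    """Ceiling division, the same arithmetic used to size the launch grid."""
    return (a + b - 1) // b


def silu_blocked(values: list[float], block_size: int) -> list[float]:
    """CPU emulation of a 1D blocked SiLU kernel."""
    n_elements = len(values)
    out = [0.0] * n_elements
    for pid in range(cdiv(n_elements, block_size)):  # one "program" per block
        block_start = pid * block_size
        for lane in range(block_size):  # lanes of the block
            offset = block_start + lane
            if offset < n_elements:  # mask: skip lanes past the end of the data
                x = values[offset]
                out[offset] = x * (1.0 / (1.0 + math.exp(-x)))
    return out


result = silu_blocked([0.0, 1.0, -1.0, 2.0, -2.0], block_size=2)
print([round(v, 4) for v in result])  # [0.0, 0.7311, -0.2689, 1.7616, -0.2384]
```

With five elements and a block size of two, three programs are launched and the last one has a single active lane, which is exactly the situation the mask exists for.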
It turns out that both triton_silu_triton_op and triton_silu_pytorch_op can be traced and compiled by torch.compile(fullgraph=True). However, for torch.export, the Triton kernel triton_silu_kernel can only be exported in the following three cases:
1. triton_silu_triton_op (registered with @triton_op and wrapped with wrap_triton) and torch.export with strict=False.
2. triton_silu_triton_op (registered with @triton_op and wrapped with wrap_triton) and torch.export with strict=True.
3. triton_silu_pytorch_op (not registered with @triton_op and no wrap_triton) and torch.export with strict=True.
The triton_silu_pytorch_op case (not registered with @triton_op and no wrap_triton) works with strict=True because strict export uses TorchDynamo-based tracing, which can see through the Python function and trace into the Triton kernel launch, even though the kernel is not registered as a transparent custom operation. With strict=False, torch.export instead runs the Python code directly on FakeTensors, so the raw kernel launch fails.
Triton Kernel, torch.compile, torch.export, and AOTInductor
import argparse
import copy
import os
import random
import shutil

import torch
import torch.profiler
import triton
import triton.language as tl
from torch.export import export, Dim
from torch.library import triton_op, wrap_triton


# 1. Define the Pure Triton Kernel
@triton.jit
def triton_silu_kernel(in_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    x = tl.load(in_ptr + offsets, mask=mask)

    # Keep everything in float32 for the math
    x_f32 = x.to(tl.float32)
    sigmoid_x = 1.0 / (1.0 + tl.exp(-x_f32))

    # Multiply in float32, THEN cast back to the input dtype (bfloat16 here)
    out = (x_f32 * sigmoid_x).to(x.dtype)

    tl.store(out_ptr + offsets, out, mask=mask)


# 2. Register with triton_op to ensure transparent tracing by torch.export
@triton_op("custom_ops::triton_silu_triton_op", mutates_args={})
def triton_silu_triton_op(x: torch.Tensor) -> torch.Tensor:
    # Enforce contiguous memory to ensure 1D pointer arithmetic is safe
    assert x.is_contiguous(), "Input tensor must be contiguous"

    out = torch.empty_like(x)
    n_elements = x.numel()

    def grid(meta):
        return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]), )

    wrap_triton(triton_silu_kernel)[grid](x, out, n_elements, BLOCK_SIZE=1024)
    return out


# Alternative PyTorch operator that uses the same Triton kernel but is not registered with triton_op.
def triton_silu_pytorch_op(x: torch.Tensor) -> torch.Tensor:
    # Enforce contiguous memory to ensure 1D pointer arithmetic is safe
    assert x.is_contiguous(), "Input tensor must be contiguous"

    out = torch.empty_like(x)
    n_elements = x.numel()

    def grid(meta):
        return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]), )

    triton_silu_kernel[grid](x, out, n_elements, BLOCK_SIZE=1024)
    return out


def pytorch_silu_eager(x: torch.Tensor) -> torch.Tensor:
    return torch.nn.functional.silu(x)


# 3. Model Architecture
class CustomModel(torch.nn.Module):

    def __init__(self,
                 in_features=128,
                 out_features=64,
                 silu_op="triton_silu_triton_op"):
        super().__init__()
        self.linear1 = torch.nn.Linear(in_features, out_features)
        self.linear2 = torch.nn.Linear(out_features, in_features)
        self.silu_op = silu_op

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.linear1(x)
        # Dispatch to the selected SiLU implementation
        if self.silu_op == "triton_silu_triton_op":
            x = torch.ops.custom_ops.triton_silu_triton_op.default(x)
        elif self.silu_op == "triton_silu_pytorch_op":
            x = triton_silu_pytorch_op(x)
        elif self.silu_op == "pytorch_silu_eager":
            x = pytorch_silu_eager(x)
        else:
            raise ValueError(f"Invalid silu_op: {self.silu_op}")
        x = self.linear2(x)
        return x


def main():
    # Set random seed for reproducibility
    random_seed = 42
    torch.manual_seed(random_seed)
    torch.cuda.manual_seed_all(random_seed)
    random.seed(random_seed)
    # For deterministic behavior (optional, may affect performance)
    torch.use_deterministic_algorithms(True, warn_only=True)

    # Remove TorchInductor and Triton kernel cache if present
    inductor_cache_dir = "/tmp/torchinductor_root/"
    if os.path.exists(inductor_cache_dir):
        print(f"Removing TorchInductor cache directory: {inductor_cache_dir}")
        shutil.rmtree(inductor_cache_dir)

    parser = argparse.ArgumentParser(
        description="Torch Triton Export with AOTInductor")
    parser.add_argument(
        "--use_registered_triton_op",
        action="store_true",
        default=False,
        help="Use registered custom Triton op (default: False)")
    parser.add_argument(
        "--strict_export",
        action="store_true",
        default=False,
        help="Enable strict export (TorchDynamo tracing, default: False)")
    # torch.compile always uses fullgraph=True, so there is no flag for it
    args = parser.parse_args()

    use_registered_triton_op = args.use_registered_triton_op
    strict_export = args.strict_export

    silu_op = "triton_silu_triton_op" if use_registered_triton_op else "triton_silu_pytorch_op"
    print(f"Using silu_op: {silu_op}")

    in_features = 128
    out_features = 64
    device = "cuda"
    dtype = torch.bfloat16
    atol = 1e-5
    rtol = 1e-5

    # 4. Instantiate model and weights in BF16
    model = CustomModel(in_features=in_features,
                        out_features=out_features,
                        silu_op=silu_op).to(device=device, dtype=dtype)
    # Randomly initialize all parameters
    for param in model.parameters():
        if param.requires_grad:
            torch.nn.init.uniform_(param, -1.0, 1.0)
    model.eval()

    # 5. Define dynamic batch dimension constraints (batch size can vary from 1 to 1024)
    batch_dim = Dim("batch", min=1, max=1024)
    dynamic_shapes = {"x": {0: batch_dim}}

    # 6. Prepare sample input
    sample_input = torch.randn(8, in_features, device=device, dtype=dtype)

    # 7. Always TorchDynamo fullgraph compile
    print("--- Cleaning torch.compile cache ---")
    torch._dynamo.reset()  # Clean torch.compile cache
    print("--- Compiling model with torch.compile(fullgraph=True) ---")
    torch_compiled_model = torch.compile(model, fullgraph=True)

    # 8. Trace the model using torch.export (AOTInductor path)
    print(f"--- Exporting model via torch.export (strict={strict_export}) ---")
    exported_program = export(model,
                              args=(sample_input, ),
                              dynamic_shapes=dynamic_shapes,
                              strict=strict_export)
    print("Model successfully traced and exported!")
    print("\nGraph Nodes Extracted:")
    exported_program.graph.print_tabular()

    # 9. Compile and Package via AOTInductor
    print("\n--- Compiling and Packaging via AOTInductor ---")
    output_package = "/tmp/compiled_model.pt2"

    # Clean up previous artifacts if they exist
    if os.path.exists(output_package):
        os.remove(output_package)

    # Instruct Inductor to accept user-defined Triton kernels natively
    torch._inductor.config.static_launch_user_defined_triton_kernels = True

    # Use the unified package compiler. This wraps weights, metadata,
    # and the compiled binary artifact natively inside a single zipped .pt2 container file.
    package_path = torch._inductor.aoti_compile_and_package(
        exported_program,
        package_path=output_package,
    )
    print(
        f"Compilation finished! Self-contained package saved to: {package_path}"
    )

    # 10. Correctness Verification
    # Prepare inference input for correctness and profiling
    inference_input = torch.randn(16, in_features, device=device, dtype=dtype)

    # Run both models to get outputs for correctness check
    with torch.no_grad():
        torch_compiled_output = torch_compiled_model(inference_input)

    compiled_runner = torch._inductor.aoti_load_package(package_path)
    with torch.no_grad():
        aotinductor_compiled_output = compiled_runner(inference_input)

    # Reference output: copy model and set silu_op to pytorch_silu_eager
    reference_model = copy.deepcopy(model)
    reference_model.silu_op = "pytorch_silu_eager"
    reference_model.eval()
    with torch.no_grad():
        eager_output = reference_model(inference_input)

    is_torch_compile_correct = torch.allclose(torch_compiled_output,
                                              eager_output,
                                              atol=atol,
                                              rtol=rtol)
    is_aotinductor_correct = torch.allclose(aotinductor_compiled_output,
                                            eager_output,
                                            atol=atol,
                                            rtol=rtol)
    is_outputs_match = torch.allclose(torch_compiled_output,
                                      aotinductor_compiled_output,
                                      atol=atol,
                                      rtol=rtol)

    print(f"torch_compiled_model output shape: {torch_compiled_output.shape}")
    print(
        f"aotinductor_compiled_model output shape: {aotinductor_compiled_output.shape}"
    )
    print(f"eager output shape: {eager_output.shape}")
    print(
        f"torch.compile correctness vs eager? -> **{is_torch_compile_correct}**"
    )
    print(
        f"aotinductor correctness vs eager? -> **{is_aotinductor_correct}**"
    )
    print(f"torch.compile vs aotinductor match? -> **{is_outputs_match}**")

    # --- Run profiling at the end of the program ---
    # Define file paths for profiling traces
    torch_compile_profiler_path = "./torch_compile_profiler.json"
    aotinductor_profiler_path = "./aotinductor_profiler.json"

    print(
        "\n--- Running torch_compiled_model (fullgraph=True) with profiler ---"
    )
    activities = [
        torch.profiler.ProfilerActivity.CPU,
        torch.profiler.ProfilerActivity.CUDA,
    ]
    warmup = 3
    steps = 5
    schedule = torch.profiler.schedule(wait=0,
                                       warmup=warmup,
                                       active=steps,
                                       repeat=1)

    with torch.profiler.profile(
            activities=activities,
            schedule=schedule,
            record_shapes=True,
            with_flops=True,
    ) as prof:
        for step in range(warmup + steps):
            with torch.profiler.record_function(f"step_{step}"):
                with torch.no_grad():
                    torch_compiled_output = torch_compiled_model(
                        inference_input)
            prof.step()
    prof.export_chrome_trace(torch_compile_profiler_path)
    print(
        f"Profiling trace for torch_compiled_model saved to {torch_compile_profiler_path}"
    )

    print(
        "\n--- AOTInductor Compiled Model Package & Running Inference with profiler ---"
    )
    compiled_runner = torch._inductor.aoti_load_package(package_path)
    with torch.profiler.profile(
            activities=activities,
            schedule=schedule,
            record_shapes=True,
            with_flops=True,
    ) as prof:
        for step in range(warmup + steps):
            with torch.profiler.record_function(f"step_{step}"):
                with torch.no_grad():
                    aotinductor_compiled_output = compiled_runner(
                        inference_input)
            prof.step()
    prof.export_chrome_trace(aotinductor_profiler_path)
    print(
        f"Profiling trace for aotinductor_compiled_output saved to {aotinductor_profiler_path}"
    )


if __name__ == "__main__":
    main()
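triton_silu_pytorch_op and strict=False
In this case, torch.export fails with an error telling us to wrap the Triton kernel in an opaque custom operator, which is confusing and defeats our purpose of keeping the kernel visible to the compiler. The traceback shows why: non-strict export runs the Python function directly on FakeTensors, and the Triton launcher fails when it tries to access the data pointer of a FakeTensor while binding the kernel arguments.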
$ python torch_triton_export_aotinductor.py
Removing TorchInductor cache directory: /tmp/torchinductor_root/
Using silu_op: triton_silu_pytorch_op
--- Cleaning torch.compile cache ---
--- Compiling model with torch.compile(fullgraph=True) ---
--- Exporting model via torch.export (strict=False) ---
Traceback (most recent call last):
  File "/mnt/torch_triton_export_aotinductor.py", line 249, in <module>
    main()
  File "/mnt/torch_triton_export_aotinductor.py", line 169, in main
    exported_program = export(model,
                       ^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/export/__init__.py", line 205, in export
    raise e
  File "/usr/local/lib/python3.12/dist-packages/torch/export/__init__.py", line 171, in export
    return _export(
           ^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/export/_trace.py", line 1343, in wrapper
    raise e
  File "/usr/local/lib/python3.12/dist-packages/torch/export/_trace.py", line 1309, in wrapper
    ep = fn(*args, **kwargs)
         ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/export/exported_program.py", line 124, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_utils_internal.py", line 96, in wrapper_function
    return function(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/export/_trace.py", line 2508, in _export
    ep = _export_for_training(
         ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/export/_trace.py", line 1343, in wrapper
    raise e
  File "/usr/local/lib/python3.12/dist-packages/torch/export/_trace.py", line 1309, in wrapper
    ep = fn(*args, **kwargs)
         ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/export/exported_program.py", line 124, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/export/_trace.py", line 2296, in _export_for_training
    export_artifact = export_func(
                      ^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/export/_trace.py", line 2225, in _non_strict_export
    aten_export_artifact = _to_aten_func(
                           ^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/export/_trace.py", line 2002, in _export_to_aten_ir_make_fx
    gm, graph_signature = transform(_make_fx_helper)(
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/export/_trace.py", line 2132, in _aot_export_non_strict
    gm, sig = aot_export(stack, wrapped_mod, args, kwargs=kwargs, **flags)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/export/_trace.py", line 1910, in _make_fx_helper
    gm = make_fx(
         ^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/fx/experimental/proxy_tensor.py", line 2826, in wrapped
    return make_fx_tracer.trace(f, *args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/fx/experimental/proxy_tensor.py", line 2727, in trace
    return self._trace_inner(f, *args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/fx/experimental/proxy_tensor.py", line 2688, in _trace_inner
    t = dispatch_trace(
        ^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_compile.py", line 54, in inner
    return disable_fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py", line 1255, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1533, in dispatch_trace
    graph = tracer.trace(root, concrete_args)  # type: ignore[arg-type]
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/fx/experimental/proxy_tensor.py", line 2264, in trace
    res = super().trace(root, concrete_args)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py", line 1255, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/fx/_symbolic_trace.py", line 890, in trace
    (self.create_arg(fn(*args)),),
                     ^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1603, in wrapped
    out = f(*tensors)  # type:ignore[call-arg]
          ^^^^^^^^^^^
  File "<string>", line 1, in <lambda>
  File "/usr/local/lib/python3.12/dist-packages/torch/export/_trace.py", line 1794, in wrapped_fn
    return tuple(flat_fn(*args))
                 ^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_functorch/_aot_autograd/utils.py", line 204, in flat_fn
    tree_out = fn(*args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_functorch/_aot_autograd/graph_capture_wrappers.py", line 1507, in functional_call
    out = mod(*args[params_len:], **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/fx/_symbolic_trace.py", line 864, in module_call_wrapper
    return self.call_module(mod, forward, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/fx/experimental/proxy_tensor.py", line 2353, in call_module
    return Tracer.call_module(self, m, forward, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/fx/_symbolic_trace.py", line 572, in call_module
    ret_val = forward(*args, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/fx/_symbolic_trace.py", line 857, in forward
    return _orig_module_call(mod, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1778, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1789, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/export/_trace.py", line 2116, in forward
    tree_out = mod(*args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/fx/_symbolic_trace.py", line 864, in module_call_wrapper
    return self.call_module(mod, forward, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/fx/experimental/proxy_tensor.py", line 2353, in call_module
    return Tracer.call_module(self, m, forward, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/fx/_symbolic_trace.py", line 572, in call_module
    ret_val = forward(*args, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/fx/_symbolic_trace.py", line 857, in forward
    return _orig_module_call(mod, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1778, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1789, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/torch_triton_export_aotinductor.py", line 87, in forward
    x = triton_silu_pytorch_op(x)
        ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/torch_triton_export_aotinductor.py", line 61, in triton_silu_pytorch_op
    triton_silu_kernel[grid](x, out, n_elements, BLOCK_SIZE=1024)
  File "/usr/local/lib/python3.12/dist-packages/triton/runtime/jit.py", line 370, in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/triton/runtime/jit.py", line 723, in run
    bound_args, specialization, options = binder(*args, **kwargs)
                                          ^^^^^^^^^^^^^^^^^^^^^^^
  File "<string>", line 4, in dynamic_func
  File "/usr/local/lib/python3.12/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1654, in __torch_function__
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1725, in __torch_function__
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_export/non_strict_utils.py", line 1159, in __torch_function__
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Cannot access data pointer of Tensor (e.g. FakeTensor, FunctionalTensor). If you're using torch.compile/export/fx, it is likely that we are erroneously tracing into a custom kernel. To fix this, please wrap the custom kernel into an opaque custom op. Please see the following for details: https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html
triton_silu_pytorch_op and strict=True
Using triton_silu_pytorch_op with torch.export and strict=True works because strict export uses TorchDynamo-based tracing, which can see through the Python function and trace into the Triton kernel launch, even though the kernel is not registered as a transparent custom operation. The exported graph records the launch as a triton_kernel_wrapper_mutation node, with the grid size computed symbolically from the dynamic batch dimension.
$ python torch_triton_export_aotinductor.py --strict_export
Removing TorchInductor cache directory: /tmp/torchinductor_root/
Using silu_op: triton_silu_pytorch_op
--- Cleaning torch.compile cache ---
--- Compiling model with torch.compile(fullgraph=True) ---
--- Exporting model via torch.export (strict=True) ---
/usr/local/lib/python3.12/dist-packages/torch/utils/_config_module.py:540: FutureWarning: torch._dynamo.config.skip_code_recursive_on_recompile_limit_hit is deprecated and does not do anything. It will be removed in a future version of PyTorch.
  config[key] = copy.deepcopy(getattr(self, key))
Model successfully traced and exported!

Graph Nodes Extracted:
opcode         name                                  target                          args                                            kwargs
-------------  ------------------------------------  ------------------------------  ----------------------------------------------  ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------
placeholder    p_linear1_weight                      p_linear1_weight                ()                                              {}
placeholder    p_linear1_bias                        p_linear1_bias                  ()                                              {}
placeholder    p_linear2_weight                      p_linear2_weight                ()                                              {}
placeholder    p_linear2_bias                        p_linear2_bias                  ()                                              {}
placeholder    x                                     x                               ()                                              {}
call_function  sym_size_int                          aten.sym_size.int               (x, 0)                                          {}
call_function  linear                                aten.linear.default             (x, p_linear1_weight, p_linear1_bias)           {}
call_function  empty_like                            aten.empty_like.default         (linear,)                                       {'pin_memory': False}
call_function  mul                                   <built-in function mul>         (64, sym_size_int)                              {}
call_function  add                                   <built-in function add>         (mul, 1024)                                     {}
call_function  sub                                   <built-in function sub>         (add, 1)                                        {}
call_function  floordiv                              <built-in function floordiv>    (sub, 1024)                                     {}
call_function  triton_kernel_wrapper_mutation_proxy  triton_kernel_wrapper_mutation  ()                                              {'kernel_idx': 0, 'constant_args_idx': 0, 'grid': [(floordiv, 1, 1)], 'tma_descriptor_metadata': {}, 'kwargs': {'in_ptr': linear, 'out_ptr': empty_like, 'n_elements': mul}}
call_function  linear_1                              aten.linear.default             (empty_like, p_linear2_weight, p_linear2_bias)  {}
output         output                                output                          ((linear_1,),)                                  {}

--- Compiling and Packaging via AOTInductor ---
/usr/lib/python3.12/copyreg.py:99: FutureWarning: `isinstance(treespec, LeafSpec)` is deprecated, use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead.
  return cls.__new__(cls, *args)
Compilation finished! Self-contained package saved to: /tmp/compiled_model.pt2
--- Running torch_compiled_model (fullgraph=True) ---
--- AOTInductor Compiled Model Package & Running Inference ---
torch_compiled_model output shape: torch.Size([16, 128])
aotinductor_compiled_model output shape: torch.Size([16, 128])
eager output shape: torch.Size([16, 128])
torch.compile correctness vs eager? -> **True**
aotinductor correctness vs eager? -> **True**
torch.compile vs aotinductor match? -> **True**
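The mul, add, sub, and floordiv nodes in the exported graph above are the traced form of the ceiling division that sizes the launch grid: the element count is 64 (the out_features of the first linear layer) times the dynamic batch size, and the block size is 1024. The following plain-Python sketch replays that arithmetic node by node; launch_grid is a hypothetical helper name of mine, not something PyTorch provides.

```python
def launch_grid(batch_size: int,
                out_features: int = 64,
                block_size: int = 1024) -> tuple:
    """Replay the grid arithmetic recorded in the exported graph."""
    n_elements = out_features * batch_size  # mul      = 64 * sym_size_int
    padded = n_elements + block_size  # add      = mul + 1024
    padded_minus_one = padded - 1  # sub      = add - 1
    num_blocks = padded_minus_one // block_size  # floordiv = sub // 1024
    return (num_blocks, 1, 1)


for batch_size in (1, 16, 17, 1024):
    print(batch_size, launch_grid(batch_size))
# 1 (1, 1, 1)
# 16 (1, 1, 1)
# 17 (2, 1, 1)
# 1024 (64, 1, 1)
```

Because the grid is expressed in terms of sym_size_int rather than a constant, the same exported program can launch the right number of blocks for any batch size within the declared range of 1 to 1024.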
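triton_silu_triton_op and strict=False
Because triton_silu_triton_op is registered with @triton_op and launches the kernel through wrap_triton, torch.export can export it regardless of whether strict=True or strict=False is used. In the exported graph, the operator appears as a single custom_ops.triton_silu_triton_op.default node.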
$ python torch_triton_export_aotinductor.py --use_registered_triton_op
Removing TorchInductor cache directory: /tmp/torchinductor_root/
Using silu_op: triton_silu_triton_op
--- Cleaning torch.compile cache ---
--- Compiling model with torch.compile(fullgraph=True) ---
--- Exporting model via torch.export (strict=False) ---
Model successfully traced and exported!

Graph Nodes Extracted:
opcode         name                   target                                    args                                                       kwargs
-------------  ---------------------  ----------------------------------------  ---------------------------------------------------------  --------
placeholder    p_linear1_weight       p_linear1_weight                          ()                                                         {}
placeholder    p_linear1_bias         p_linear1_bias                            ()                                                         {}
placeholder    p_linear2_weight       p_linear2_weight                          ()                                                         {}
placeholder    p_linear2_bias         p_linear2_bias                            ()                                                         {}
placeholder    x                      x                                         ()                                                         {}
call_function  linear                 aten.linear.default                       (x, p_linear1_weight, p_linear1_bias)                      {}
call_function  triton_silu_triton_op  custom_ops.triton_silu_triton_op.default  (linear,)                                                  {}
call_function  linear_1               aten.linear.default                       (triton_silu_triton_op, p_linear2_weight, p_linear2_bias)  {}
output         output                 output                                    ((linear_1,),)                                             {}

--- Compiling and Packaging via AOTInductor ---
/usr/lib/python3.12/copyreg.py:99: FutureWarning: `isinstance(treespec, LeafSpec)` is deprecated, use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead.
  return cls.__new__(cls, *args)
/usr/local/lib/python3.12/dist-packages/torch/utils/_config_module.py:540: FutureWarning: torch._dynamo.config.skip_code_recursive_on_recompile_limit_hit is deprecated and does not do anything. It will be removed in a future version of PyTorch.
  config[key] = copy.deepcopy(getattr(self, key))
Compilation finished! Self-contained package saved to: /tmp/compiled_model.pt2
--- Running torch_compiled_model (fullgraph=True) ---
--- AOTInductor Compiled Model Package & Running Inference ---
torch_compiled_model output shape: torch.Size([16, 128])
aotinductor_compiled_model output shape: torch.Size([16, 128])
eager output shape: torch.Size([16, 128])
torch.compile correctness vs eager? -> **True**
aotinductor correctness vs eager? -> **True**
torch.compile vs aotinductor match? -> **True**
triton_silu_triton_op and strict=True
As in the non-strict case, triton_silu_triton_op, registered with @triton_op and wrapped with wrap_triton, exports successfully with strict=True, and the exported graph is the same.
$ python torch_triton_export_aotinductor.py --use_registered_triton_op --strict_export
Removing TorchInductor cache directory: /tmp/torchinductor_root/
Using silu_op: triton_silu_triton_op
--- Cleaning torch.compile cache ---
--- Compiling model with torch.compile(fullgraph=True) ---
--- Exporting model via torch.export (strict=True) ---
/usr/local/lib/python3.12/dist-packages/torch/utils/_config_module.py:540: FutureWarning: torch._dynamo.config.skip_code_recursive_on_recompile_limit_hit is deprecated and does not do anything. It will be removed in a future version of PyTorch.
  config[key] = copy.deepcopy(getattr(self, key))
Model successfully traced and exported!

Graph Nodes Extracted:
opcode         name                   target                                    args                                                       kwargs
-------------  ---------------------  ----------------------------------------  ---------------------------------------------------------  --------
placeholder    p_linear1_weight       p_linear1_weight                          ()                                                         {}
placeholder    p_linear1_bias         p_linear1_bias                            ()                                                         {}
placeholder    p_linear2_weight       p_linear2_weight                          ()                                                         {}
placeholder    p_linear2_bias         p_linear2_bias                            ()                                                         {}
placeholder    x                      x                                         ()                                                         {}
call_function  linear                 aten.linear.default                       (x, p_linear1_weight, p_linear1_bias)                      {}
call_function  triton_silu_triton_op  custom_ops.triton_silu_triton_op.default  (linear,)                                                  {}
call_function  linear_1               aten.linear.default                       (triton_silu_triton_op, p_linear2_weight, p_linear2_bias)  {}
output         output                 output                                    ((linear_1,),)                                             {}

--- Compiling and Packaging via AOTInductor ---
/usr/lib/python3.12/copyreg.py:99: FutureWarning: `isinstance(treespec, LeafSpec)` is deprecated, use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead.
  return cls.__new__(cls, *args)
Compilation finished! Self-contained package saved to: /tmp/compiled_model.pt2
--- Running torch_compiled_model (fullgraph=True) ---
--- AOTInductor Compiled Model Package & Running Inference ---
torch_compiled_model output shape: torch.Size([16, 128])
aotinductor_compiled_model output shape: torch.Size([16, 128])
eager output shape: torch.Size([16, 128])
torch.compile correctness vs eager? -> **True**
aotinductor correctness vs eager? -> **True**
torch.compile vs aotinductor match? -> **True**
Conclusions
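We can use Triton kernels in PyTorch not only via JIT compilation but also via ahead-of-time compilation with AOTInductor. For torch.export to succeed in both strict and non-strict modes, the kernel should be registered with @triton_op and launched through wrap_triton; a raw kernel launch from a plain Python function exports only with strict=True.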
References
PyTorch Triton Kernel Transparent Tracing and Compilation
https://leimao.github.io/blog/PyTorch-Triton-Kernel-Transparent-Tracing-and-Compilation/