Introduction
Predicated execution and conditional execution are two different approaches to handling control flow in programming, particularly in the context of parallel computing and GPU programming. Predicated execution involves executing all instructions but only committing the results of those instructions that meet a certain condition, while conditional execution involves executing instructions only if a certain condition is met.
In this blog post, I would like to quickly show a few examples of predicated execution and conditional execution, mostly in PyTorch with additional JAX and CUDA examples, and discuss how to choose between them in different scenarios.
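The two definitions above can be made concrete with a minimal pure-Python sketch. This is not PyTorch code, and Python loops only stand in for GPU instructions. The names branch_true, branch_false, predicated, and conditional are hypothetical, and the branch formulas are borrowed from the CUDA example later in this post. The predicated version evaluates both branch expressions for every element and selects the result with an arithmetic mask, so its control flow never depends on the data. The conditional version evaluates only the branch that the condition picks.

```python
from collections import Counter

# Counts how many times each branch body runs (illustration only).
calls: Counter[str] = Counter()


def branch_true(x: float) -> float:
    calls["true"] += 1
    return 2.0 * x + 1.0


def branch_false(x: float) -> float:
    calls["false"] += 1
    return x * x - 1.0


def predicated(xs: list[float]) -> list[float]:
    ys = []
    for x in xs:
        m = float(x > 0.0)  # predicate turned into a 0.0/1.0 mask
        a = branch_true(x)  # always executed
        b = branch_false(x)  # always executed
        # Branch-free select by blending. A blend would propagate NaN/inf from
        # the unselected value; real select instructions and torch.where pick.
        ys.append(m * a + (1.0 - m) * b)
    return ys


def conditional(xs: list[float]) -> list[float]:
    ys = []
    for x in xs:
        # Only the branch chosen by the predicate is executed.
        ys.append(branch_true(x) if x > 0.0 else branch_false(x))
    return ys


xs = [-2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0, 3.0]
calls.clear()
y_pred = predicated(xs)
pred_calls = dict(calls)
calls.clear()
y_cond = conditional(xs)
cond_calls = dict(calls)
```

Both functions return the same values, but the predicated version runs each branch body 8 times, while the conditional version runs each branch body only 4 times in total for this input.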
Predicated Execution
A common PyTorch API that uses predicated execution is torch.where. The torch.where function takes a condition tensor and two other tensors, broadcasts them to a common shape, and returns a new tensor whose elements are taken from the first input where the condition is True and from the second input otherwise.
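The selection rule itself can be modeled in a few lines of pure Python. The where function below is a hypothetical 1-D sketch, not PyTorch's implementation. It shows how a scalar condition, like the 0-dim pred tensor in the following example, is broadcast to every element and therefore selects one entire input. Note that both inputs already exist by the time the selection happens.

```python
from typing import Sequence, Union


def where(cond: Union[bool, Sequence[bool]],
          a: Sequence[float],
          b: Sequence[float]) -> list[float]:
    """Hypothetical 1-D, pure-Python model of torch.where's selection rule."""
    if len(a) != len(b):
        raise ValueError("a and b must have the same length in this sketch")
    # A scalar condition is broadcast to every element, as a 0-dim tensor would be.
    conds = [cond] * len(a) if isinstance(cond, bool) else list(cond)
    if len(conds) != len(a):
        raise ValueError("cond must be a scalar or match the length of a and b")
    # Element-wise selection: take a[i] where cond[i] is True, otherwise b[i].
    return [ai if ci else bi for ci, ai, bi in zip(conds, a, b)]
```

For example, where(True, [1.0, 2.0], [3.0, 4.0]) returns the whole first input, while where([True, False], [1.0, 2.0], [3.0, 4.0]) mixes elements from both inputs.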
In the following example, we have two neural network branches that produce outputs with exactly the same metadata (shape, dtype, and device). We use torch.where to select the output of one of the two branches based on a scalar condition tensor. However, both branches are executed regardless of the condition.
```python
import torch
import torch.nn as nn


class BranchA(nn.Module):

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim, dim), nn.ReLU())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


class BranchB(nn.Module):

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim, dim), nn.GELU())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


class PredicatedModel(nn.Module):

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.branch_a = BranchA(dim)
        self.branch_b = BranchB(dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Both branches are executed and produce tensors with identical metadata.
        y_a = self.branch_a(x)
        y_b = self.branch_b(x)
        # Use the same scalar predicate as the conditional-execution examples.
        pred = x.mean() > 0
        # Predicated execution: both branches run, then one full output is selected.
        y = torch.where(pred, y_a, y_b)
        return y


if __name__ == "__main__":

    torch.manual_seed(0)
    device = "cuda"
    batch_size = 8
    hidden_dim = 16
    x = torch.randn(batch_size, hidden_dim, device=device)
    with torch.device(device):
        model = PredicatedModel(hidden_dim)

    print("=== Predicated Execution ===")

    print("=== Eager Execution ===")
    torch.cuda.set_sync_debug_mode(debug_mode="warn")
    y = model(x)
    torch.cuda.set_sync_debug_mode(debug_mode="default")

    compiled_model = torch.compile(model)
    print("=== Compiled Execution ===")
    # Warm up
    y = compiled_model(x)
    torch.cuda.set_sync_debug_mode(debug_mode="warn")
    y = compiled_model(x)
    torch.cuda.set_sync_debug_mode(debug_mode="default")
```
Conditional Execution
In contrast, conditional execution executes only the branch whose condition is met. In this example, we use torch.cond to achieve conditional execution. The torch.cond function takes a predicate (a Python bool or a one-element boolean tensor), two branch functions, and a tuple of operands, and it calls only the branch function selected by the predicate. Both branch functions must return outputs of the same type and shape.
```python
import torch
import torch.nn as nn


class BranchA(nn.Module):

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim, dim), nn.ReLU())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


class BranchB(nn.Module):

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim, dim), nn.GELU())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


class ConditionalModel(nn.Module):

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.branch_a = BranchA(dim)
        self.branch_b = BranchB(dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        pred = x.mean() > 0

        def true_fn(x_: torch.Tensor) -> torch.Tensor:
            return self.branch_a(x_)

        def false_fn(x_: torch.Tensor) -> torch.Tensor:
            return self.branch_b(x_)

        # torch.cond traces both branches but executes only one branch at runtime.
        return torch.cond(pred, true_fn, false_fn, (x, ))


if __name__ == "__main__":

    torch.manual_seed(0)
    device = "cuda"
    batch_size = 8
    hidden_dim = 16
    x = torch.randn(batch_size, hidden_dim, device=device)
    with torch.device(device):
        model = ConditionalModel(hidden_dim)

    print("=== Conditional Execution ===")

    print("=== Eager Execution ===")
    # "warn" makes PyTorch warn on synchronizing CUDA operations.
    torch.cuda.set_sync_debug_mode(debug_mode="warn")
    y = model(x)
    torch.cuda.set_sync_debug_mode(debug_mode="default")

    compiled_model = torch.compile(model)
    print("=== Compiled Execution ===")
    # Warm up
    y = compiled_model(x)
    torch.cuda.set_sync_debug_mode(debug_mode="warn")
    y = compiled_model(x)
    torch.cuda.set_sync_debug_mode(debug_mode="default")
```
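The calling convention of torch.cond described above can be modeled in pure Python. The toy_cond function below is a hypothetical sketch, not PyTorch's implementation. It only shows the runtime behavior: both branch functions accept the same operands tuple, and exactly one of them is invoked. It deliberately skips the tracing of both branches that the real operator performs.

```python
from typing import Callable


def toy_cond(pred: bool,
             true_fn: Callable[..., list[float]],
             false_fn: Callable[..., list[float]],
             operands: tuple) -> list[float]:
    """Hypothetical pure-Python model of torch.cond's runtime behavior."""
    fn = true_fn if pred else false_fn
    return fn(*operands)  # exactly one branch runs


called: list[str] = []  # records which branch actually ran


def true_fn(x_: list[float]) -> list[float]:
    called.append("true")
    return [max(v, 0.0) for v in x_]  # ReLU-like branch


def false_fn(x_: list[float]) -> list[float]:
    called.append("false")
    return [-v for v in x_]


x = [1.0, -2.0, 3.0]
pred = sum(x) / len(x) > 0  # same "mean > 0" predicate as the examples
y = toy_cond(pred, true_fn, false_fn, (x,))
```

Here the mean is positive, so only true_fn runs and y is [1.0, 0.0, 3.0].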
One caveat of torch.cond is that it results in a host-device synchronization: the predicate is computed on the GPU, but the host has to read its value before it can decide which branch's instructions to dispatch to the GPU at runtime.
Changing the execution framework does not eliminate this synchronization. In experiments not discussed here, I verified that it still occurs whether the model is compiled by torch.compile, AOTInductor, or TensorRT and run on GPU, or written with jax.lax.cond in JAX and compiled by XLA for both GPU and TPU (v5e).
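Where exactly the synchronization comes from can be illustrated with a toy model. The DeviceScalar class and its item and select methods are hypothetical stand-ins, not a PyTorch or JAX API. A device-side select consumes the predicate without the host ever looking at it, whereas conditional dispatch has to copy the predicate to the host first, which is the point where the host waits for the device.

```python
class DeviceScalar:
    """Toy stand-in for a 0-dim boolean tensor that lives in device memory."""

    def __init__(self, value: bool) -> None:
        self._value = value
        self.host_reads = 0  # number of device-to-host reads (synchronizations)

    def item(self) -> bool:
        # Models copying the value to the host, which makes the host wait
        # until the device has finished computing it.
        self.host_reads += 1
        return self._value

    def select(self, a: list[float], b: list[float]) -> list[float]:
        # Models a device-side select such as torch.where: the predicate is
        # consumed on the device, so the host never reads it.
        return a if self._value else b


def branch_a(x: list[float]) -> list[float]:
    return [2.0 * v for v in x]


def branch_b(x: list[float]) -> list[float]:
    return [-v for v in x]


def run_predicated(pred: DeviceScalar, x: list[float]) -> list[float]:
    # Both branches are enqueued without looking at the predicate.
    return pred.select(branch_a(x), branch_b(x))


def run_conditional(pred: DeviceScalar, x: list[float]) -> list[float]:
    # The host must know the predicate before it can launch the taken branch.
    return branch_a(x) if pred.item() else branch_b(x)


pred = DeviceScalar(True)
x = [1.0, -2.0]
y_pred = run_predicated(pred, x)
reads_after_predicated = pred.host_reads
y_cond = run_conditional(pred, x)
reads_after_conditional = pred.host_reads
```

Both paths produce the same output, but only the conditional path performs a host read of the predicate.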
To run conditional execution using jax.lax.cond, we could use the following script on GPU or TPU platforms.
```python
"""Unified JAX conditional execution script for GPU and TPU.

Usage:
    python conditional_execution_jax.py                 # Auto-detect device
    python conditional_execution_jax.py --device gpu    # Force GPU
    python conditional_execution_jax.py --device tpu    # Force TPU
"""

import argparse
import os
import shutil

# Make JAX preallocate 20% of GPU memory instead of the default fraction.
# This has no effect on TPU. Set it before importing JAX.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.2"

import jax
import jax.numpy as jnp


def init_linear(in_dim: int, out_dim: int, key: jax.Array) -> dict:
    # He (Kaiming) normal initialization. This does not exactly match the
    # PyTorch nn.Linear default (Kaiming uniform weights, uniform bias).
    scale = jnp.sqrt(2.0 / in_dim)
    weight = jax.random.normal(key, (out_dim, in_dim)) * scale
    bias = jnp.zeros((out_dim,))
    return {"weight": weight, "bias": bias}


def linear(params: dict, x: jax.Array) -> jax.Array:
    return x @ params["weight"].T + params["bias"]


def branch_a(params: dict, x: jax.Array) -> jax.Array:
    return jax.nn.relu(linear(params, x))


def branch_b(params: dict, x: jax.Array) -> jax.Array:
    return jax.nn.gelu(linear(params, x))


def init_model(dim: int, key: jax.Array) -> dict:
    key_a, key_b = jax.random.split(key)
    return {
        "branch_a": init_linear(dim, dim, key_a),
        "branch_b": init_linear(dim, dim, key_b),
    }


def forward(params: dict, x: jax.Array) -> jax.Array:
    pred = x.mean() > 0
    # jax.lax.cond traces both branches but executes only one branch at
    # runtime.
    return jax.lax.cond(
        pred,
        lambda x_: branch_a(params["branch_a"], x_),
        lambda x_: branch_b(params["branch_b"], x_),
        x,
    )


def select_device(device_type: str) -> tuple[jax.Device, str]:
    """Select device by type (gpu/tpu) with fallback and auto-detection.

    Returns:
        (device, device_name_str): JAX device object and display name
    """
    device_type = device_type.lower()

    def try_get_devices(backend: str) -> list:
        """Safely attempt to get devices for a backend; return [] if unavailable."""
        try:
            return jax.devices(backend)
        except RuntimeError:
            return []

    if device_type == "auto":
        # Auto-detect: prefer TPU, fallback to GPU, then CPU
        tpu_devices = try_get_devices("tpu")
        gpu_devices = try_get_devices("gpu")
        if tpu_devices:
            return tpu_devices[0], "tpu"
        elif gpu_devices:
            return gpu_devices[0], "gpu"
        else:
            return jax.devices("cpu")[0], "cpu"
    elif device_type == "tpu":
        devices = try_get_devices("tpu")
        if not devices:
            raise RuntimeError(
                "No TPU devices found. Use --device gpu or --device auto.")
        return devices[0], "tpu"
    elif device_type == "gpu":
        devices = try_get_devices("gpu")
        if not devices:
            raise RuntimeError(
                "No GPU devices found. Use --device tpu or --device auto.")
        return devices[0], "gpu"
    elif device_type == "cpu":
        return jax.devices("cpu")[0], "cpu"
    else:
        raise ValueError(f"Unknown device type: {device_type}")


if __name__ == "__main__":

    parser = argparse.ArgumentParser(
        description="JAX conditional execution with lax.cond",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="Default: auto-detect device (TPU > GPU > CPU)",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="auto",
        choices=["auto", "gpu", "tpu", "cpu"],
        help="Device to use (default: auto)",
    )
    # Use parse_known_args to ignore extra arguments from Jupyter/Colab kernel
    args, _ = parser.parse_known_args()

    device, device_name = select_device(args.device)
    jax.config.update("jax_default_device", device)

    batch_size = 8
    hidden_dim = 16
    key = jax.random.PRNGKey(0)
    x_key, model_key = jax.random.split(key)
    x = jax.random.normal(x_key, (batch_size, hidden_dim))
    params = init_model(hidden_dim, model_key)

    # Explicitly place all tensors on device and materialize them before profiling.
    x = jax.device_put(x, device)
    params = jax.tree.map(lambda leaf: jax.device_put(leaf, device), params)
    jax.block_until_ready((x, params))

    print(f"=== Conditional Execution on {device_name.upper()} ===\n")

    print("=== Eager Execution ===")
    y = forward(params, x)
    print(f"y shape: {y.shape}, y[0]: {y[0]}\n")

    compiled_forward = jax.jit(forward)
    print("=== Compiled Execution ===")
    # Warm up.
    y = compiled_forward(params, x)
    y = compiled_forward(params, x)
    print(f"y shape: {y.shape}, y[0]: {y[0]}\n")

    print("=== Profiling ===")
    trace_dir = f"jax_profile_{device_name}"
    # Remove existing trace directory to always overwrite.
    shutil.rmtree(trace_dir, ignore_errors=True)
    # jax.profiler.trace() generates a TensorBoard/Perfetto-compatible trace.
    # View with: tensorboard --logdir=<trace_dir>
    # or upload the .json.gz files to https://ui.perfetto.dev
    with jax.profiler.trace(trace_dir):
        for _ in range(10):
            y = compiled_forward(params, x)
        jax.block_until_ready(y)
    print(f"✓ Exported trace to {trace_dir}/ (view with TensorBoard or Perfetto)")
```
Predicated Execution VS Conditional Execution
In the above examples, the PredicatedModel and the ConditionalModel produce the same output for the same input. The difference is that the PredicatedModel executes both branches and then selects the output based on the condition, whereas the ConditionalModel executes only the branch selected by the condition, at the cost of one host-device synchronization.
Consequently, the choice between predicated execution and conditional execution depends on the specific use case. If both branches are lightweight, so that executing the unneeded one adds little overhead to the system, we should consider predicated execution. However, if at least one branch is heavy and the cost of the host-device synchronization is lower than the cost of executing both branches, we should consider conditional execution.
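This trade-off can be written as a back-of-the-envelope cost model. The sketch below is an assumption, not a measurement: all timings are hypothetical, p_true is an assumed probability that the predicate is true, and t_sync is meant to include not only the copy latency but also the lost overlap between host-side dispatch and device execution after the stall.

```python
def predicated_cost(t_a: float, t_b: float, t_select: float) -> float:
    # Both branches always run, followed by the select.
    return t_a + t_b + t_select


def conditional_cost(t_a: float, t_b: float, t_sync: float,
                     p_true: float) -> float:
    # One host-device synchronization, then only the taken branch runs
    # (expected cost over the assumed probability p_true).
    return t_sync + p_true * t_a + (1.0 - p_true) * t_b


def prefer_conditional(t_a: float, t_b: float, t_select: float,
                       t_sync: float, p_true: float) -> bool:
    return (conditional_cost(t_a, t_b, t_sync, p_true)
            < predicated_cost(t_a, t_b, t_select))


# Hypothetical timings in milliseconds, for illustration only.
light = prefer_conditional(t_a=0.01, t_b=0.01, t_select=0.005,
                           t_sync=0.05, p_true=0.5)
heavy = prefer_conditional(t_a=5.0, t_b=0.02, t_select=0.005,
                           t_sync=0.05, p_true=0.5)
```

With two lightweight branches the synchronization dominates and predicated execution wins (light is False); with one heavy branch, skipping it half of the time outweighs the synchronization (heavy is True).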
Conditional Execution Kernel Fusion Optimization
I had previously wondered why conditional execution cannot be optimized to avoid the host-device synchronization. After all, a neural network compiler can see the code of the condition and of both branches, so technically it could generate a single kernel that evaluates the condition and performs the conditional execution entirely on the GPU, without any host-device synchronization.
In the following example, I have implemented a CUDA kernel that performs conditional execution on the GPU without host-device synchronization. The kernel first computes the condition (whether the mean of the input is positive) and then executes one of the two branches based on the computed condition. Because the scalar predicate stays in device memory and the blocks are synchronized with grid.sync() under a cooperative launch, the host never has to read the predicate to dispatch a branch.
```cpp
#include <algorithm>
#include <cooperative_groups.h>
#include <cstdio>
#include <cstdlib>
#include <cuda_runtime.h>
#include <iostream>

namespace cg = cooperative_groups;

#define CHECK_CUDA_ERROR(val) check((val), #val, __FILE__, __LINE__)
void check(cudaError_t err, char const* func, char const* file, int line)
{
    if (err != cudaSuccess)
    {
        std::cerr << "CUDA Runtime Error at: " << file << ":" << line
                  << std::endl;
        std::cerr << cudaGetErrorString(err) << " " << func << std::endl;
        std::exit(EXIT_FAILURE);
    }
}

#define CHECK_LAST_CUDA_ERROR() check_last(__FILE__, __LINE__)
void check_last(char const* file, int line)
{
    cudaError_t const err{cudaGetLastError()};
    if (err != cudaSuccess)
    {
        std::cerr << "CUDA Runtime Error at: " << file << ":" << line
                  << std::endl;
        std::cerr << cudaGetErrorString(err) << std::endl;
        std::exit(EXIT_FAILURE);
    }
}

__device__ __forceinline__ bool compute_condition(float sum, int n)
{
    return (sum / static_cast<float>(n)) > 0.0f;
}

__device__ __forceinline__ float branch_true(float x)
{
    // Example branch: y = 2x + 1
    return 2.0f * x + 1.0f;
}

__device__ __forceinline__ float branch_false(float x)
{
    // Example branch: y = x^2 - 1
    return x * x - 1.0f;
}

template <size_t NUM_THREADS>
__global__ void
conditional_kernel(float const* __restrict__ x, float* __restrict__ y, size_t n,
                   float* __restrict__ block_sums, bool* __restrict__ pred_ptr)
{
    cg::grid_group grid = cg::this_grid();

    size_t const tid{threadIdx.x};
    size_t const gtid{blockIdx.x * NUM_THREADS + tid};
    size_t const stride{NUM_THREADS * gridDim.x};

    __shared__ float sdata[NUM_THREADS];

    // Phase 1: each block reduces its grid-strided elements to a partial sum.
    float local_sum = 0.0f;
    for (size_t i{gtid}; i < n; i += stride)
    {
        local_sum += x[i];
    }
    sdata[tid] = local_sum;
    __syncthreads();

    for (int s = blockDim.x / 2; s > 0; s >>= 1)
    {
        if (tid < s)
        {
            sdata[tid] += sdata[tid + s];
        }
        __syncthreads();
    }

    if (tid == 0)
    {
        block_sums[blockIdx.x] = sdata[0];
    }

    // Wait until every block has written its partial sum.
    grid.sync();

    // Phase 2: one block reduces partial sums to a scalar predicate for the
    // whole grid.
    if (blockIdx.x == 0)
    {
        float partial = 0.0f;
        for (size_t b{tid}; b < gridDim.x; b += NUM_THREADS)
        {
            partial += block_sums[b];
        }
        sdata[tid] = partial;
        __syncthreads();

        for (int s = blockDim.x / 2; s > 0; s >>= 1)
        {
            if (tid < s)
            {
                sdata[tid] += sdata[tid + s];
            }
            __syncthreads();
        }

        if (tid == 0)
        {
            *pred_ptr = compute_condition(sdata[0], n);
        }
    }

    // Wait until the predicate is visible to every block.
    grid.sync();

    // Phase 3: every thread applies the same branch; the predicate never
    // leaves device memory.
    bool const pred{*pred_ptr};
    for (size_t idx{gtid}; idx < n; idx += stride)
    {
        float const xi{x[idx]};
        y[idx] = pred ? branch_true(xi) : branch_false(xi);
    }
}

int main()
{
    constexpr size_t N{8};
    constexpr size_t NUM_THREADS{256};
    // Host arrays use the compile-time constant N because standard C++ has no
    // variable-length arrays. n stays non-const so that &n can be stored in
    // the void* kernel argument array.
    size_t n{N};
    float h_x[N] = {-2.0f, -1.0f, -0.5f, 0.0f, 0.5f, 1.0f, 2.0f, 3.0f};
    float h_y[N] = {0.0f};

    float* d_x = nullptr;
    float* d_y = nullptr;
    float* d_block_sums = nullptr;
    bool* d_pred = nullptr;

    cudaDeviceProp prop{};
    CHECK_CUDA_ERROR(cudaGetDeviceProperties(&prop, 0));
    if (!prop.cooperativeLaunch)
    {
        std::printf("Device does not support cooperative kernel launch.\n");
        return 0;
    }

    CHECK_CUDA_ERROR(cudaMalloc(&d_x, n * sizeof(float)));
    CHECK_CUDA_ERROR(cudaMalloc(&d_y, n * sizeof(float)));
    CHECK_CUDA_ERROR(
        cudaMemcpy(d_x, h_x, n * sizeof(float), cudaMemcpyHostToDevice));

    // A cooperative launch requires all blocks to be resident at the same
    // time, so the grid size is capped by the occupancy limit.
    int blocks_per_sm = 0;
    CHECK_CUDA_ERROR(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
        &blocks_per_sm, conditional_kernel<NUM_THREADS>, NUM_THREADS, 0));
    int const max_blocks{blocks_per_sm * prop.multiProcessorCount};
    int const blocks{std::max(
        1, std::min(static_cast<int>((n + NUM_THREADS - 1) / NUM_THREADS),
                    max_blocks))};

    CHECK_CUDA_ERROR(cudaMalloc(&d_block_sums, blocks * sizeof(float)));
    CHECK_CUDA_ERROR(cudaMalloc(&d_pred, sizeof(bool)));

    void* args[] = {&d_x, &d_y, &n, &d_block_sums, &d_pred};
    CHECK_CUDA_ERROR(cudaLaunchCooperativeKernel(
        reinterpret_cast<void*>(conditional_kernel<NUM_THREADS>), blocks,
        NUM_THREADS, args, 0, 0));
    CHECK_LAST_CUDA_ERROR();

    CHECK_CUDA_ERROR(
        cudaMemcpy(h_y, d_y, n * sizeof(float), cudaMemcpyDeviceToHost));
    CHECK_CUDA_ERROR(cudaDeviceSynchronize());

    for (size_t i{0}; i < n; ++i)
    {
        std::printf("x=%6.2f -> y=%6.2f\n", h_x[i], h_y[i]);
    }

    CHECK_CUDA_ERROR(cudaFree(d_x));
    CHECK_CUDA_ERROR(cudaFree(d_y));
    CHECK_CUDA_ERROR(cudaFree(d_block_sums));
    CHECK_CUDA_ERROR(cudaFree(d_pred));

    return 0;
}
```
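To check the kernel's logic without a GPU, its three phases can be replayed sequentially in Python. The simulate_conditional_kernel function is a hypothetical host-side model, not part of the CUDA program. It sums the inputs in a different order than the GPU tree reduction, so floating-point results could differ slightly for inputs that are not exactly representable. For the 8-element input in main, the launch uses a single block, but the model also accepts other grid shapes.

```python
def simulate_conditional_kernel(x: list[float], num_blocks: int,
                                num_threads: int) -> list[float]:
    """Sequential model of the three phases of the fused CUDA kernel."""
    n = len(x)
    stride = num_blocks * num_threads
    # Phase 1: each block sums the grid-strided elements of its threads.
    block_sums = [0.0] * num_blocks
    for b in range(num_blocks):
        for t in range(num_threads):
            gtid = b * num_threads + t
            for i in range(gtid, n, stride):
                block_sums[b] += x[i]
    # grid.sync(): all partial sums are visible before phase 2.
    # Phase 2: block 0 reduces the partial sums and computes the predicate.
    pred = (sum(block_sums) / n) > 0.0
    # grid.sync(): every block now reads the same predicate.
    # Phase 3: every element takes the same branch.
    return [2.0 * v + 1.0 if pred else v * v - 1.0 for v in x]
```

For the input in main, the mean is 0.375, so every element takes branch_true and the result is [-3.0, -1.0, 0.0, 1.0, 2.0, 3.0, 5.0, 7.0], regardless of how the work is split across blocks.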
Although this looks like a good solution, it is hard to generalize this approach to all scenarios: the condition and branches can be arbitrarily complex (for example, each branch may consist of many kernels with different launch configurations), and the compiler may not be able to generate a single kernel that handles all cases.
When the condition and branches are simple, the compiler might be able to generate a single kernel that performs conditional execution on the GPU without host-device synchronization, as demonstrated in the above example. However, its performance might be only slightly better than that of the predicated execution approach, since it saves only two kernel launch overheads and one branch execution. If the predicated execution approach is also optimized by the compiler, the two branches may be horizontally fused, depending on their instructions, which makes the performance difference between the two approaches even smaller.
For these reasons, it is not worth the effort to implement a general conditional execution kernel fusion optimization. Instead, users will have to choose between predicated execution and conditional execution based on the specific use case, as discussed in the previous section.
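The horizontal fusion mentioned above can be sketched in pure Python, where each loop over the data stands in for one kernel and the function names are hypothetical. The unfused predicated version makes three passes (branch A, branch B, select). The fused version merges the two independent branch computations, plus the select in this sketch, into a single pass. Compared with the fused conditional kernel, the only remaining extra work is the arithmetic of the untaken branch, which is small for cheap element-wise branches.

```python
def branch_a(v: float) -> float:
    return 2.0 * v + 1.0


def branch_b(v: float) -> float:
    return v * v - 1.0


def unfused_predicated(x: list[float], pred: bool) -> list[float]:
    # Three separate passes over the data, standing in for three kernels.
    y_a = [branch_a(v) for v in x]                    # "kernel" 1: branch A
    y_b = [branch_b(v) for v in x]                    # "kernel" 2: branch B
    return [a if pred else b for a, b in zip(y_a, y_b)]  # "kernel" 3: select


def fused_predicated(x: list[float], pred: bool) -> list[float]:
    # One pass: both branches and the select are computed per element.
    y = []
    for v in x:
        a = branch_a(v)
        b = branch_b(v)
        y.append(a if pred else b)
    return y
```

Both versions return identical results; the fused one avoids materializing the intermediate y_a and y_b lists, which in the GPU setting corresponds to fewer kernel launches and less memory traffic.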
References
Predicated Execution VS Conditional Execution
https://leimao.github.io/blog/Predicated-Execution-VS-Conditional-Execution/