Predicated Execution VS Conditional Execution

Introduction

Predicated execution and conditional execution are two different approaches to handling control flow, particularly in parallel computing and GPU programming. Predicated execution runs the instructions of every branch but commits only the results that satisfy a condition, whereas conditional execution runs only the instructions of the branch whose condition is met.

In this blog post, I would like to quickly show a few examples of predicated execution and conditional execution using PyTorch, where torch.where implements the former and torch.cond implements the latter, and discuss how to choose between them in different scenarios, since each comes with different performance trade-offs in GPU programming.

Predicated Execution

A common PyTorch API that uses predicated execution is torch.where. The torch.where function takes a condition tensor and two other tensors, and returns a new tensor in which each element is selected from one of the two input tensors according to the corresponding value in the condition tensor.

In the following example, we have two neural network branches that produce outputs with exactly the same metadata (shape, dtype, device). We can use torch.where to select the output of one of the two branches based on a condition tensor. However, both branches are executed regardless of the condition.
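To make that last point concrete before the full PyTorch model, here is a minimal framework-free sketch of the two strategies. It is only an illustration in plain Python and does not use torch.where or torch.cond; the helper names make_branches, predicated_select, and conditional_select are hypothetical and exist only for this sketch. Each branch records its own invocation so that we can see which branches actually ran.

```python
from typing import Callable, List, Tuple

Branch = Callable[[float], float]


def make_branches() -> Tuple[Branch, Branch, List[str]]:
    """Create two toy branches that log every call (hypothetical helper)."""
    log: List[str] = []

    def branch_a(x: float) -> float:
        log.append("a")
        return max(x, 0.0)

    def branch_b(x: float) -> float:
        log.append("b")
        return 0.5 * x

    return branch_a, branch_b, log


def predicated_select(pred: bool, x: float, fn_a: Branch, fn_b: Branch) -> float:
    # Predicated execution: run both branches, then keep one result.
    y_a = fn_a(x)
    y_b = fn_b(x)
    return y_a if pred else y_b


def conditional_select(pred: bool, x: float, fn_a: Branch, fn_b: Branch) -> float:
    # Conditional execution: run only the branch selected by the predicate.
    return fn_a(x) if pred else fn_b(x)


if __name__ == "__main__":
    fn_a, fn_b, log = make_branches()
    y = predicated_select(True, 2.0, fn_a, fn_b)
    print("predicated:", y, "branches run:", log)

    fn_a, fn_b, log = make_branches()
    y = conditional_select(True, 2.0, fn_a, fn_b)
    print("conditional:", y, "branches run:", log)
```

Both functions return the same value for the same predicate. The difference is only in how much work is done: the predicated version always pays for both branches, while the conditional version pays for one branch plus the cost of deciding which one to run.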
```python
import torch
import torch.nn as nn


class BranchA(nn.Module):

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim, dim), nn.ReLU())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


class BranchB(nn.Module):

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim, dim), nn.GELU())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


class PredicatedModel(nn.Module):

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.branch_a = BranchA(dim)
        self.branch_b = BranchB(dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Both branches are executed and produce tensors with identical metadata.
        y_a = self.branch_a(x)
        y_b = self.branch_b(x)
        # Use the same scalar predicate as the conditional-execution examples.
        pred = x.mean() > 0
        # Predicated execution: both branches run, then one full output is selected.
        y = torch.where(pred, y_a, y_b)
        return y


if __name__ == "__main__":

    torch.manual_seed(0)

    device = "cuda"
    batch_size = 8
    hidden_dim = 16

    x = torch.randn(batch_size, hidden_dim, device=device)

    with torch.device(device):
        model = PredicatedModel(hidden_dim)

    print("=== Predicated Execution ===")

    print("=== Eager Execution ===")
    torch.cuda.set_sync_debug_mode(debug_mode="warn")
    y = model(x)
    torch.cuda.set_sync_debug_mode(debug_mode="default")

    compiled_model = torch.compile(model)

    print("=== Compiled Execution ===")
    # Warm up
    y = compiled_model(x)
    torch.cuda.set_sync_debug_mode(debug_mode="warn")
    y = compiled_model(x)
    torch.cuda.set_sync_debug_mode(debug_mode="default")
```

Conditional Execution

In contrast, conditional execution executes only the branch of code whose condition is met. In this example, we use torch.cond to achieve conditional execution.
The torch.cond function takes a predicate and two functions, and it executes only the function that corresponds to the value of the predicate.

```python
import torch
import torch.nn as nn


class BranchA(nn.Module):

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim, dim), nn.ReLU())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


class BranchB(nn.Module):

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim, dim), nn.GELU())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


class ConditionalModel(nn.Module):

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.branch_a = BranchA(dim)
        self.branch_b = BranchB(dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        pred = x.mean() > 0

        def true_fn(x_: torch.Tensor) -> torch.Tensor:
            return self.branch_a(x_)

        def false_fn(x_: torch.Tensor) -> torch.Tensor:
            return self.branch_b(x_)

        # torch.cond traces both branches but executes only one branch at runtime.
        return torch.cond(pred, true_fn, false_fn, (x,))


if __name__ == "__main__":

    torch.manual_seed(0)
    torch.cuda.set_sync_debug_mode(debug_mode="warn")

    device = "cuda"
    batch_size = 8
    hidden_dim = 16

    x = torch.randn(batch_size, hidden_dim, device=device)

    with torch.device(device):
        model = ConditionalModel(hidden_dim)

    print("=== Conditional Execution ===")

    print("=== Eager Execution ===")
    torch.cuda.set_sync_debug_mode(debug_mode="warn")
    y = model(x)
    torch.cuda.set_sync_debug_mode(debug_mode="default")

    compiled_model = torch.compile(model)

    print("=== Compiled Execution ===")
    # Warm up
    y = compiled_model(x)
    torch.cuda.set_sync_debug_mode(debug_mode="warn")
    y = compiled_model(x)
    torch.cuda.set_sync_debug_mode(debug_mode="default")
```

One caveat of torch.cond is that it results in host-device synchronization, because the instructions of the selected branch are dispatched to the GPU dynamically at runtime and the host therefore has to wait for the predicate value computed on the device. Changing the execution framework does not eliminate this synchronization. In my experiments, which are not discussed here, the synchronization was always present, whether the model was compiled by torch.compile, AOTInductor, or TensorRT and run on GPU, or written with jax.lax.cond in JAX, compiled by XLA, and run on GPU or TPU (v5e).

To run conditional execution using jax.lax.cond, we could use the following script on GPU or TPU platforms.
```python
"""Unified JAX conditional execution script for GPU and TPU.

Usage:
    python conditional_execution_jax.py               # Auto-detect device
    python conditional_execution_jax.py --device gpu  # Force GPU
    python conditional_execution_jax.py --device tpu  # Force TPU
"""

import argparse
import os
import shutil
import sys

# Set GPU memory fraction (only on GPU devices; no-op on TPU).
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.2"

import jax
import jax.numpy as jnp


def init_linear(in_dim: int, out_dim: int, key: jax.Array) -> dict:
    # He (Kaiming) normal scale, sqrt(2 / fan_in). This approximates, but does
    # not exactly match, the PyTorch nn.Linear default initialization.
    scale = jnp.sqrt(2.0 / in_dim)
    weight = jax.random.normal(key, (out_dim, in_dim)) * scale
    bias = jnp.zeros((out_dim,))
    return {"weight": weight, "bias": bias}


def linear(params: dict, x: jax.Array) -> jax.Array:
    return x @ params["weight"].T + params["bias"]


def branch_a(params: dict, x: jax.Array) -> jax.Array:
    return jax.nn.relu(linear(params, x))


def branch_b(params: dict, x: jax.Array) -> jax.Array:
    return jax.nn.gelu(linear(params, x))


def init_model(dim: int, key: jax.Array) -> dict:
    key_a, key_b = jax.random.split(key)
    return {
        "branch_a": init_linear(dim, dim, key_a),
        "branch_b": init_linear(dim, dim, key_b),
    }


def forward(params: dict, x: jax.Array) -> jax.Array:
    pred = x.mean() > 0
    # jax.lax.cond traces both branches but executes only one branch at runtime.
    return jax.lax.cond(
        pred,
        lambda x_: branch_a(params["branch_a"], x_),
        lambda x_: branch_b(params["branch_b"], x_),
        x,
    )


def select_device(device_type: str) -> tuple[jax.Device, str]:
    """Select device by type (gpu/tpu) with fallback and auto-detection.

    Returns:
        (device, device_name_str): JAX device object and display name
    """
    device_type = device_type.lower()

    def try_get_devices(backend: str) -> list:
        """Safely attempt to get devices for a backend; return [] if unavailable."""
        try:
            return jax.devices(backend)
        except RuntimeError:
            return []

    if device_type == "auto":
        # Auto-detect: prefer TPU, fallback to GPU, then CPU
        tpu_devices = try_get_devices("tpu")
        gpu_devices = try_get_devices("gpu")
        if tpu_devices:
            return tpu_devices[0], "tpu"
        elif gpu_devices:
            return gpu_devices[0], "gpu"
        else:
            return jax.devices("cpu")[0], "cpu"
    elif device_type == "tpu":
        devices = try_get_devices("tpu")
        if not devices:
            raise RuntimeError(
                "No TPU devices found. Use --device gpu or --device auto.")
        return devices[0], "tpu"
    elif device_type == "gpu":
        devices = try_get_devices("gpu")
        if not devices:
            raise RuntimeError(
                "No GPU devices found. Use --device tpu or --device auto.")
        return devices[0], "gpu"
    elif device_type == "cpu":
        return jax.devices("cpu")[0], "cpu"
    else:
        raise ValueError(f"Unknown device type: {device_type}")


if __name__ == "__main__":

    parser = argparse.ArgumentParser(
        description="JAX conditional execution with lax.cond",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="Default: auto-detect device (TPU > GPU > CPU)",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="auto",
        choices=["auto", "gpu", "tpu", "cpu"],
        help="Device to use (default: auto)",
    )
    # Use parse_known_args to ignore extra arguments from Jupyter/Colab kernel
    args, _ = parser.parse_known_args()

    device, device_name = select_device(args.device)
    jax.config.update("jax_default_device", device)

    batch_size = 8
    hidden_dim = 16

    key = jax.random.PRNGKey(0)
    x_key, model_key = jax.random.split(key)
    x = jax.random.normal(x_key, (batch_size, hidden_dim))
    params = init_model(hidden_dim, model_key)

    # Explicitly place all tensors on device and materialize them before profiling.
    x = jax.device_put(x, device)
    params = jax.tree.map(lambda leaf: jax.device_put(leaf, device), params)
    jax.block_until_ready((x, params))

    print(f"=== Conditional Execution on {device_name.upper()} ===\n")

    print("=== Eager Execution ===")
    y = forward(params, x)
    print(f"y shape: {y.shape}, y[0]: {y[0]}\n")

    compiled_forward = jax.jit(forward)

    print("=== Compiled Execution ===")
    # Warm up.
    y = compiled_forward(params, x)
    y = compiled_forward(params, x)
    print(f"y shape: {y.shape}, y[0]: {y[0]}\n")

    print("=== Profiling ===")
    trace_dir = f"jax_profile_{device_name}"
    # Remove existing trace directory to always overwrite.
    shutil.rmtree(trace_dir, ignore_errors=True)
    # jax.profiler.trace generates a TensorBoard/Perfetto-compatible trace.
    # View with: tensorboard --logdir=
```