{"slug": "why-jax-is-a-much-better-backend-for-quantum-circuit-simulation-than-pytorch", "title": "Why JAX Is a Much Better Backend for Quantum Circuit Simulation Than PyTorch", "summary": "A developer benchmarked quantum circuit simulation backends using a 20-qubit VQE workload on an NVIDIA RTX 5090 GPU. TensorCircuit-NG's JAX/XLA backend was 12.4x faster than its PyTorch backend and 15.7x faster than TorchQuantum for post-compilation value-and-gradient steps. The JAX backend needed a 53.5-second compile/warmup phase, compared with 0.48 seconds for the PyTorch backend. After that, each step took 0.0265 seconds on JAX versus 0.3299 seconds on PyTorch. The results suggest that JAX's aggressive whole-program compilation through XLA gives a significant performance advantage for quantum simulation workloads involving irregular tensor contractions and automatic differentiation.", "body_md": "Modern quantum circuit simulation is not just “machine learning with complex tensors.” It involves irregular tensor contractions, sparse operators, statevector transformations, and automatic differentiation through all of them. This makes backend choice unusually important. 
A backend that is excellent for standard neural-network layers may still be a poor fit for general quantum simulation workloads.\n\nWe benchmarked this with a simple VQE workload for the 1D transverse-field Ising model (TFIM), as in [the script](https://github.com/tensorcircuit/tensorcircuit-ng/blob/master/examples/benchmark_jax_vs_torch_tfim.py):\n\n```\nH = -sum_i Z_i Z_{i+1} - sum_i X_i\n```\n\nThe benchmark uses 20 qubits, 10 ansatz layers, complex64 precision, and a single NVIDIA RTX 5090 GPU.\n\n| Backend | Compile / warmup (s) | Value+grad step (s) |\n|---|---|---|\n| TensorCircuit-NG, JAX backend | 53.53 | 0.0265 |\n| TensorCircuit-NG, PyTorch backend | 0.48 | 0.3299 |\n| TorchQuantum (optimized TFIM expectation, not the default path) | 0.81 | 0.4172 |\n\nAfter compilation, the JAX backend is about **12.4x faster** per value-and-gradient step than TensorCircuit-NG’s PyTorch backend (0.3299 / 0.0265). It is also about **15.7x faster** than TorchQuantum (0.4172 / 0.0265).\n\nThe compile time tells the other half of the story: JAX pays a much larger upfront XLA compilation cost. After compilation, however, XLA produces a far more effective execution plan for this workload. With the numbers above, the extra compilation time is recovered after roughly 175 value-and-gradient steps ((53.53 - 0.48) / (0.3299 - 0.0265) ≈ 175). This is exactly the tradeoff we want in VQE, QAOA, time evolution, and many other iterative algorithms: pay once, run many times.\n\nQuantum circuit simulation stresses a backend differently from ordinary deep learning. The workload mixes tensor-network contraction, sparse Hamiltonian application, and reverse-mode differentiation. JAX/XLA is designed to see the whole computation and optimize it aggressively as a compiled program on the target device.\n\nPyTorch, in contrast, is strongest where the workload resembles standard neural network layers. 
For more general tensor programs, especially tensor-network-like simulation code, its compiler stack is less aggressive and less predictable.\n\nIn this benchmark, the same TensorCircuit-NG algorithm runs more than an order of magnitude faster on JAX than on PyTorch after compilation.\n\nWe also compared against TorchQuantum as a representative PyTorch-native quantum circuit package. To make the comparison generous, we did not use its generic Pauli-string expectation path. That built-in route tends to materialize dense Pauli operators, which is slow and does not scale. Instead, we implemented a TFIM-specific expectation computed directly from the statevector:\n\n- `ZZ` terms are evaluated from the computational-basis probabilities and precomputed sign tensors.\n- `X` terms are evaluated by flipping the corresponding state axis and taking the inner product with the original state.\n\nThis is already a substantial low-level optimization. Even with that help, TorchQuantum remains about 15.7x slower than TensorCircuit-NG on the JAX backend. If you prefer PyTorch, TensorCircuit-NG’s PyTorch backend is still the better choice. It has both a shorter warmup (0.48 s vs. 0.81 s) and a faster step (0.3299 s vs. 0.4172 s).\n\nThe lesson is not merely that one package is faster than another. The deeper point is that backend architecture matters. Quantum simulation benefits from a compiler that can optimize a whole differentiable tensor program, not just a collection of familiar machine-learning layers.\n\nFor TensorCircuit-NG, the JAX backend gives exactly that: a high-level quantum programming interface backed by XLA’s aggressive compilation. 
The result is a backend that is not only elegant for research code, but also, on this benchmark, more than an order of magnitude faster for differentiable quantum simulation.", "url": "https://wpnews.pro/news/why-jax-is-a-much-better-backend-for-quantum-circuit-simulation-than-pytorch", "canonical_source": "https://dev.to/refractionray/why-jax-is-a-much-better-backend-for-quantum-circuit-simulation-than-pytorch-ak6", "published_at": "2026-06-06 05:01:36+00:00", "updated_at": "2026-06-06 05:11:57.755321+00:00", "lang": "en", "topics": ["machine-learning", "neural-networks", "ai-research", "ai-infrastructure", "ai-tools"], "entities": ["JAX", "PyTorch", "TensorCircuit-NG", "TorchQuantum", "NVIDIA", "XLA", "RTX 5090", "VQE"], "alternates": {"html": "https://wpnews.pro/news/why-jax-is-a-much-better-backend-for-quantum-circuit-simulation-than-pytorch", "markdown": "https://wpnews.pro/news/why-jax-is-a-much-better-backend-for-quantum-circuit-simulation-than-pytorch.md", "text": "https://wpnews.pro/news/why-jax-is-a-much-better-backend-for-quantum-circuit-simulation-than-pytorch.txt", "jsonld": "https://wpnews.pro/news/why-jax-is-a-much-better-backend-for-quantum-circuit-simulation-than-pytorch.jsonld"}}