{"slug": "why-tpus-aren-t-popular-even-though-they-re-cheaper-per-token", "title": "Why TPUs Aren't Popular (Even Though They're Cheaper Per Token)", "summary": "NVIDIA GPUs handle variable-length inference requests dynamically without recompilation, while TPUs and AWS Trainium require fixed shapes compiled ahead of time, causing crashes or stalls on mismatched inputs. This architectural constraint forces TPU users to pad sequences and waste compute, whereas NVIDIA's SIMT design allows seamless concatenation of different-length requests. The static/dynamic split explains why cheaper-per-token TPUs remain niche despite their theoretical advantages.", "body_md": "If you only look at the spec sheet, the TPU story is overwhelming: lower cost-per-token, dramatically better watts-per-token, deterministic latency. Trainium tells the same story. And yet most of the industry — including the inference traffic behind ChatGPT and Claude's web UI — still runs on NVIDIA. The gap between \"cheaper on paper\" and \"what people actually deploy\" is not a marketing failure. It's an architectural tax that systolic-array silicon charges you in code, pipelines, and org structure. This post is about where that tax comes from and why only a handful of companies can afford to pay it.\n\nNVIDIA GPUs are SIMT (Single Instruction, Multiple Threads) processors. They schedule threads dynamically at runtime and page memory on demand. TPUs and AWS Trainium are not GPUs — they are **systolic arrays**: a grid of multiply-accumulate units wired directly to their neighbors, fed by an ahead-of-time compiler (XLA for TPU, the Neuron compiler for Trainium).\n\nA systolic array hits peak utilization only when the shape of the data flowing through it is **fixed at compile time**. Weights are loaded once and stay stationary in the processing elements; activations slide through like a bucket brigade. 
Change the sequence length or batch size by even one token and the data routes and memory addresses have to be recomputed, which means the compiler has to generate a *new binary*.\n\nThat single constraint is the source of every downstream pain. Here's what it forces on you at inference time:\n\n| Runtime input | NVIDIA (dynamic) | TPU / Trainium (static) |\n|---|---|---|\n| Larger than the compiled bucket | Handled by dynamic allocation | Shape-mismatch crash |\n| Smaller than the bucket | Handled with no waste | JIT recompile stall (minutes) or zero-pad waste |\n| New, unseen length | Just runs | New binary must exist, or it stalls |\n\nSo before any token reaches the chip, you need an answer to: \"what shape is this, and which precompiled binary does it route to?\" On NVIDIA you never ask that question.\n\nThe cleanest mental model: **NVIDIA is Python, TPU/Trainium is Java.**\n\n- **NVIDIA (Python):** you call `forward` and it just works, \"good enough\" fast, with no compile step in your face.\n- **TPU/Trainium (Java):** everything is compiled ahead of time into a static binary (`NEFF` for Neuron, an XLA executable for TPU). In exchange for boilerplate and rigid discipline, you get extreme execution efficiency, once everything fits the contract.\n\nAMD's Instinct line (CDNA, ROCm) sits firmly on the **NVIDIA/Python side**: SIMT, dynamic shapes, `PagedAttention` support, and a `HIPIFY` toolchain whose entire purpose is to translate your existing CUDA source to HIP with minimal changes. The static/dynamic split is the real fault line, not the vendor logos.\n\nSuppose three users hit your endpoint at once: 3,000 / 4,000 / 1,000 tokens. On NVIDIA you don't pad and you don't build a mask. You concatenate them into one flat 8,000-token buffer and hand `FlashAttention` a `cu_seqlens` index marking the boundaries:\n\n```python\nfrom flash_attn import flash_attn_varlen_func\n\n# NVIDIA: variable-length attention. 
No padding, no mask matrix.\n# Just a flat buffer + cumulative sequence lengths [0, 3000, 7000, 8000].\noutputs = flash_attn_varlen_func(\n    q, k, v,\n    cu_seqlens_q, cu_seqlens_k,\n    max_seqlen_q, max_seqlen_k,\n)\n```\n\nThe kernel reads the boundary index and never attends across a boundary, so each user's context stays isolated. No wasted FLOPs on cross-user attention. The code is \"just the model logic.\"\n\nOn a TPU you can't reshape the systolic array, so you do the opposite: force everything into one fixed `[batch, STATIC_SEQ_LEN]` rectangle and use a mask to erase the parts you don't want to count.\n\n```python\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch_xla.core.xla_model as xm\n\nclass StaticShapeAttention(nn.Module):\n    def __init__(self, d_model, n_heads):\n        super().__init__()\n        self.n_heads, self.d_k = n_heads, d_model // n_heads\n        self.q = nn.Linear(d_model, d_model)\n        self.k = nn.Linear(d_model, d_model)\n        self.v = nn.Linear(d_model, d_model)\n        self.out = nn.Linear(d_model, d_model)\n\n    def forward(self, x, attention_mask):\n        # x is ALWAYS [batch, STATIC_SEQ_LEN, d_model]. The shape never varies.\n        b, s, _ = x.size()\n        q = self.q(x).view(b, s, self.n_heads, self.d_k).transpose(1, 2)\n        k = self.k(x).view(b, s, self.n_heads, self.d_k).transpose(1, 2)\n        v = self.v(x).view(b, s, self.n_heads, self.d_k).transpose(1, 2)\n\n        scores = torch.matmul(q, k.transpose(-2, -1)) / (self.d_k ** 0.5)\n\n        # The systolic array DID compute every cell, including padding and\n        # other users' regions. 
We retroactively delete them: e^(-1e9) -> 0.\n        # attention_mask must broadcast against scores: [b, n_heads, s, s].\n        scores = scores.masked_fill(attention_mask == 0, -1e9)\n        attn = F.softmax(scores, dim=-1)\n\n        ctx = torch.matmul(attn, v).transpose(1, 2).contiguous().view(b, s, -1)\n        return self.out(ctx)\n```\n\nTwo things in that snippet are pure consequences of static silicon:\n\n- `xm.mark_step()` is the real execution trigger. Calling `model(x)` on XLA only records operations into a lazy graph; `mark_step()` compiles the accumulated graph into one fixed binary and ships it. New shape → new compile.\n- `masked_fill(..., -1e9)` is a hack, not an optimization. The array has already spent the FLOPs on every cell; the mask only zeroes the result, standing in for the `varlen` path used on NVIDIA.\n\nThe crash-on-overflow case is intuitive: feed 1,025 tokens into a binary compiled for 1,024 and you get a shape mismatch. The nastier case is *underflow*, a 100-token request hitting a 1,024 system:\n\n- Recompile for the new shape and the request stalls for minutes on JIT compilation.\n- Zero-pad up to 1,024 and the array computes `0 × 0 + 0` across ~90% of its cells, consuming full power to compute nothing. Utilization collapses.\n\nThe escape hatch is **packing**: instead of one user per bucket, tile multiple users' requests into a fixed rectangle like Tetris, and generate a segment-ID mask so attention can't bleed across users.\n\n```\nFixed bucket [ 8192 tokens ]\n├─ User A query (3000)\n├─ User B query (4000)\n├─ User C query (1000)\n└─ padding      (192)   <-- the only waste\n```\n\nIt helps to be concrete about what \"the rectangle\" physically is. When you compile with `BATCH_SIZE = 4, STATIC_SEQ_LEN = 8192`, XLA reserves **one contiguous [4, 8192] static region** in the TPU's HBM: not four independent \"rooms,\" but one big sheet the compiler hard-wires the array routes for. 
A single user rarely fills even one 8,192 lane, so the serving layer packs several requests into each lane:\n\n```\n[ One TPU processor: one static [4 x 8192] sheet ]\n\nlane[0] (8192): [ A(2000) + B(5000) + C(1000) + pad(192) ]\nlane[1] (8192): [ D(8000)                      + pad(192) ]\nlane[2] (8192): [ E(3000) + F(3000) + G(2100)  + pad(92)  ]\nlane[3] (8192): [ H(4000) + I(4000)            + pad(192) ]\n```\n\nPhysically there are 4 lanes (32K of space); logically the proxy just crammed **9 ragged users (A–I)** into them. From the application side it looks like one TPU is concurrently servicing many small requests in parallel, but underneath it's one rigid sheet with a segment mask drawn over it. The reason the hardware wants one fat sheet instead of pre-carved small rooms is pure systolic-array physics: the bigger the matrix, the higher the array's fill rate and the fewer idle cycles between feeds.\n\nDone right, padding shrinks to about 2% of the sheet above (668 of 32,768 slots), and MFU (Model FLOPs Utilization) climbs toward the best the array can deliver. But notice what you just built: a high-throughput Go/C++ proxy in front of the cluster whose only job is to catch ragged input and pack it into rectangles in real time. On NVIDIA, that entire layer **does not exist**.\n\nPeople assume `torch_xla` abstracts the hardware away because `xm.xla_device()` transparently targets both TPU and Trainium (thanks to the shared OpenXLA/PJRT runtime: `libtpu.so` for TPU, `libneuronpjrt.so` for Neuron). That's true for `model.to(device)` and basic ops. It is emphatically *not* true for the parts that matter.\n\nThe `forward` signature itself diverges:\n\n```\n# NVIDIA forward: ragged data + boundary index. 
Length is arbitrary every call.\n# (Schematic signatures: flash_attn_func and static_attn_func are placeholders.)\ndef forward(self, input_ids, cu_seqlens, max_seqlen):\n    return self.flash_attn_func(input_ids, cu_seqlens, max_seqlen)\n\n# Static forward: fixed rectangle + a mask matrix you must build yourself.\ndef forward(self, input_ids, attention_mask):  # input_ids is [batch, FixedSeqLen]\n    return self.static_attn_func(input_ids, attention_mask)\n```\n\nAnd it cascades all the way down:\n\n| Component | NVIDIA pipeline | Trainium pipeline |\n|---|---|---|\n| Inference engine | `vLLM` (CUDA), `TensorRT-LLM` | `NxD` / `vllm-neuron` |\n| Custom kernels | Triton, CUDA C++ (`FlashAttention`) | NKI (Neuron Kernel Interface), rewritten from scratch |\n| Base image | `nvcr.io/nvidia/pytorch` | AWS Neuron DLC |\n| CI build artifact | weights + CUDA/Triton binaries | weights + NEFF static binaries per bucket |\n| Deploy target | `g5` / `p5` instances | `trn1` / `inf2` instances |\n| Monitoring | `nvidia-smi`, DCGM exporter | `neuron-top`, Neuron exporter |\n\nTwo completely parallel worlds. Your CUDA container, your eval scripts, your autoscaling triggers: none of it carries over. vLLM's hardware-plugin mechanism gives you \"one skin\" at the business-logic layer, but the engine underneath is 100% separate code with separate bugs.\n\nThe data-type story isn't symmetric either. BF16 (which Google's TPU pioneered) is stable on both sides: its FP32-range exponent survives the `-1e9` mask values without going NaN. But FP8, the current throughput play, favors NVIDIA: FP8 attention scores swing hard and need **dynamic scaling** at runtime to avoid clipping. A static compiler has to bake in a fixed scale factor at compile time, so on TPU/Trainium aggressive FP8 attention risks clipping that degrades model quality. \"Just switch to FP8\" is a one-liner on NVIDIA and a research project on static silicon.\n\nThis is the part that kills adoption and nobody puts on a slide. 
On NVIDIA there's a clean abstraction boundary:\n\n```\n[ AI engineer / data scientist ]\n   architecture, hyperparams, eval\n        │\n        ▼  boundary: Hugging Face weights / standard PyTorch\n        │\n[ MLOps / LLMOps engineer ]\n   drop into vLLM, configure PagedAttention, scale out\n```\n\nThe data scientist never thinks about memory layout. The MLOps engineer never reads the attention math. They ship artifacts across a clean interface.\n\nOn TPU that wall **disappears**, because model structure is directly coupled to physical constraints:\n\n- The packing logic in the proxy (MLOps) and the segment mask in `forward` (AI engineer) are two halves of one design. Change the batching strategy and the math has to change in lockstep. You cannot split that across a spec doc.\n- Adding an `if` branch or changing layer count alters the compiled graph topology and triggers JIT stalls or OOM in production. Debugging that requires dumping the XLA HLO graph, which pulls the AI engineer into an \"infra\" incident.\n\nThe organizations that succeed on TPU (Google's Gemini team, Anthropic's Claude team, Meta's Llama-on-TPU group) abandoned the horizontal \"data science dept / infra dept\" split entirely. They run a single vertically-integrated team of people fluent in *both* the attention math and the compiler internals. Most companies cannot staff that, and the projects that try to keep the old division of labor die in a pile of compile errors and OOMs.\n\nThe whole calculus flips when **you control the input channel** so the shapes are predictable. Two clean examples:\n\n`neuronx-distributed`, with a Go/C++ proxy doing real-time packing), and Claude Code is, read cynically, the perfect input-locking channel that makes a Java-style chip worth the pain. Long-context workloads help too: a 200K-token prefill fills a 32K bucket with ~zero padding, so the static array's weakness evaporates exactly where Claude is strongest.\n\nThe inverse is just as logical, and it explains why the *chat* UIs stay on NVIDIA. 
ChatGPT and Claude.ai's web frontends accept arbitrary text, surprise image uploads, and topic switches mid-conversation. The system can't predict the shape until the user hits send. That chaos is precisely what dynamic SIMT + `PagedAttention` were built for.\n\nThe spec sheet was never lying about cost-per-token. It just wasn't pricing in the engineers, the forked pipeline, and the org redesign you have to buy first.", "url": "https://wpnews.pro/news/why-tpus-aren-t-popular-even-though-they-re-cheaper-per-token", "canonical_source": "https://dev.to/toyama0919/why-tpus-arent-popular-even-though-theyre-cheaper-per-token-188g", "published_at": "2026-06-05 08:35:50+00:00", "updated_at": "2026-06-05 08:42:37.460670+00:00", "lang": "en", "topics": ["ai-chips", "ai-infrastructure", "machine-learning", "large-language-models", "artificial-intelligence"], "entities": ["TPU", "AWS Trainium", "NVIDIA", "ChatGPT", "Claude", "XLA", "Neuron compiler", "SIMT"], "alternates": {"html": "https://wpnews.pro/news/why-tpus-aren-t-popular-even-though-they-re-cheaper-per-token", "markdown": "https://wpnews.pro/news/why-tpus-aren-t-popular-even-though-they-re-cheaper-per-token.md", "text": "https://wpnews.pro/news/why-tpus-aren-t-popular-even-though-they-re-cheaper-per-token.txt", "jsonld": "https://wpnews.pro/news/why-tpus-aren-t-popular-even-though-they-re-cheaper-per-token.jsonld"}}