# Why TPUs Aren't Popular (Even Though They're Cheaper Per Token)

> Source: <https://dev.to/toyama0919/why-tpus-arent-popular-even-though-theyre-cheaper-per-token-188g>
> Published: 2026-06-05 08:35:50+00:00

If you only look at the spec sheet, the TPU story is overwhelming: lower cost-per-token, dramatically better watts-per-token, deterministic latency. Trainium tells the same story. And yet most of the industry — including the inference traffic behind ChatGPT and Claude's web UI — still runs on NVIDIA. The gap between "cheaper on paper" and "what people actually deploy" is not a marketing failure. It's an architectural tax that systolic-array silicon charges you in code, pipelines, and org structure. This post is about where that tax comes from and why only a handful of companies can afford to pay it.

NVIDIA GPUs are SIMT (Single Instruction, Multiple Threads) processors. They schedule threads dynamically at runtime and page memory on demand. TPUs and AWS Trainium are not GPUs — they are **systolic arrays**: a grid of multiply-accumulate units wired directly to their neighbors, fed by an ahead-of-time compiler (XLA for TPU, the Neuron compiler for Trainium).

A systolic array hits peak utilization only when the shape of the data flowing through it is **fixed at compile time**. Weights are loaded once and stay stationary in the processing elements; activations slide through like a bucket brigade. Change the sequence length or batch size by even one token and the data routes and memory addresses have to be recomputed — which means the compiler has to generate a *new binary*.

That single constraint is the source of every downstream pain. Here's what it forces on you at inference time:

| Runtime input | NVIDIA (dynamic) | TPU / Trainium (static) |
|---|---|---|
| Larger than the compiled bucket | Handled by dynamic allocation | Shape-mismatch crash |
| Smaller than the bucket | Handled with no waste | JIT recompile stall (minutes) or zero-pad waste |
| New, unseen length | Just runs | New binary must exist, or it stalls |

So before any token reaches the chip, you need an answer to: "what shape is this, and which precompiled binary does it route to?" On NVIDIA you never ask that question.
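That routing question can be sketched in a few lines. This is a minimal illustration, not the API of any real serving stack: the bucket sizes and the `route_to_bucket` helper are hypothetical, and a real deployment would map each bucket to its own precompiled binary.

```python
from bisect import bisect_left

# Hypothetical bucket sizes: one precompiled binary is assumed to exist per entry.
COMPILED_BUCKETS = [512, 1024, 2048, 4096, 8192]


def route_to_bucket(num_tokens, buckets=COMPILED_BUCKETS):
    """Return (bucket_size, padding) for the smallest bucket that fits the request."""
    if num_tokens <= 0:
        raise ValueError("request must contain at least one token")
    idx = bisect_left(buckets, num_tokens)
    if idx == len(buckets):
        # Overflow: no precompiled binary is large enough for this shape.
        raise ValueError(f"{num_tokens} tokens exceeds the largest bucket {buckets[-1]}")
    bucket = buckets[idx]
    return bucket, bucket - num_tokens


if __name__ == "__main__":
    for n in (100, 1024, 1025):
        bucket, pad = route_to_bucket(n)
        print(f"{n:>5} tokens -> bucket {bucket}, {pad} padded slots ({pad / bucket:.0%} waste)")
```

Note how a single extra token (1,025 instead of 1,024) pushes the request into the next bucket and roughly half of that bucket becomes padding.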

The cleanest mental model: **NVIDIA is Python, TPU/Trainium is Java.**

- **NVIDIA (Python):** you hand ragged data to `forward` and it just works, "good enough" fast, with no compile step in your face.
- **TPU/Trainium (Java):** you declare shapes up front and compile ahead of time to a fixed binary (`NEFF` for Neuron, an XLA executable for TPU). In exchange for boilerplate and rigid discipline, you get extreme execution efficiency, once everything fits the contract.

AMD's Instinct line (CDNA, ROCm) sits firmly on the **NVIDIA/Python side**: SIMT, dynamic shapes, `PagedAttention` support, and a `HIPIFY` toolchain whose entire purpose is to translate your existing CUDA code to HIP with minimal changes. The static/dynamic split is the real fault line, not the vendor logos.

Suppose three users hit your endpoint at once: 3,000 / 4,000 / 1,000 tokens. On NVIDIA you don't pad and you don't build a mask. You concatenate them into one flat 8,000-token buffer and hand `FlashAttention` a `cu_seqlens` index marking the boundaries:

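``` python
from flash_attn import flash_attn_varlen_func

# NVIDIA: variable-length attention. No padding, no mask matrix.
# q, k, v are flat [total_tokens, n_heads, head_dim] tensors (8,000 tokens here).
# cu_seqlens_* are int32 cumulative sequence lengths: [0, 3000, 7000, 8000].
# max_seqlen_* is the longest single request in the batch (4,000 here).
outputs = flash_attn_varlen_func(
    q, k, v,
    cu_seqlens_q, cu_seqlens_k,
    max_seqlen_q, max_seqlen_k,
)
```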

The kernel reads the boundary index and isolates each user's context in hardware. No wasted FLOPs on cross-user attention. The code is "just the model logic."
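For concreteness, the boundary index for the three requests above is just a running sum of their lengths. The sketch below uses plain Python lists; a real call would pass int32 tensors on the GPU, and the helper names `build_cu_seqlens` and `owner_of` are made up for illustration.

```python
from itertools import accumulate


def build_cu_seqlens(lengths):
    """Cumulative sequence lengths: entry i is the flat-buffer offset where request i starts."""
    return [0, *accumulate(lengths)]


def owner_of(token_index, cu_seqlens):
    """Index of the request that owns a position in the flat buffer."""
    for i in range(len(cu_seqlens) - 1):
        if cu_seqlens[i] <= token_index < cu_seqlens[i + 1]:
            return i
    raise IndexError("token index is outside the packed buffer")


lengths = [3000, 4000, 1000]
cu_seqlens = build_cu_seqlens(lengths)
max_seqlen = max(lengths)
print(cu_seqlens, max_seqlen)
```

The last entry of `cu_seqlens` is the total token count of the flat buffer, and `max_seqlen` is the longest single request rather than the buffer length.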

On a TPU you can't reshape the systolic array, so you do the opposite: force everything into one fixed `[batch, STATIC_SEQ_LEN]` rectangle and use math to erase the parts you don't want computed.

``` python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_xla.core.xla_model as xm

class StaticShapeAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.n_heads, self.d_k = n_heads, d_model // n_heads
        self.q = nn.Linear(d_model, d_model)
        self.k = nn.Linear(d_model, d_model)
        self.v = nn.Linear(d_model, d_model)
        self.out = nn.Linear(d_model, d_model)

    def forward(self, x, attention_mask):
        # x is ALWAYS [batch, STATIC_SEQ_LEN, d_model]. The shape never varies.
        b, s, _ = x.size()
        q = self.q(x).view(b, s, self.n_heads, self.d_k).transpose(1, 2)
        k = self.k(x).view(b, s, self.n_heads, self.d_k).transpose(1, 2)
        v = self.v(x).view(b, s, self.n_heads, self.d_k).transpose(1, 2)

        scores = torch.matmul(q, k.transpose(-2, -1)) / (self.d_k ** 0.5)

        # The systolic array DID compute every cell, including padding and
        # other users' regions. We retroactively delete them: e^(-1e9) -> 0.
        scores = scores.masked_fill(attention_mask == 0, -1e9)
        attn = F.softmax(scores, dim=-1)

        ctx = torch.matmul(attn, v).transpose(1, 2).contiguous().view(b, s, -1)
        return self.out(ctx)
```

Two things about running that snippet are pure consequences of static silicon:

- **`xm.mark_step()` is the real execution trigger.** Calling `model(x)` on XLA only records the operations into a lazy graph; `mark_step()` compiles the accumulated graph into one fixed binary and ships it. New shape → new compile.
- **`masked_fill(..., -1e9)` is a hack, not an optimization.** The masked cells have already been computed and the mask only discards their results, whereas the NVIDIA `varlen` path never computes them in the first place.

The crash-on-overflow case is intuitive: feed 1,025 tokens into a binary compiled for 1,024 and you get a shape mismatch. The nastier case is *underflow*, a 100-token request hitting a 1,024-token binary. Either the runtime recompiles for the new shape and stalls, or it zero-pads the request to 1,024 and the array computes `0 × 0 + 0` across ~90% of its cells, consuming full power to compute nothing. Utilization collapses.

The escape hatch is **packing**: instead of one user per bucket, tile multiple users' requests into a fixed rectangle like Tetris, and generate a segment-ID mask so attention can't bleed across users.

```
Fixed bucket [ 8192 tokens ]
├─ User A query (3000)
├─ User B query (4000)
├─ User C query (1000)
└─ padding      (192)   <-- the only waste
```
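To make the segment-ID idea concrete, here is a small pure-Python sketch with a toy bucket of 8 slots instead of 8,192. The names `segment_ids` and `segment_mask` are illustrative and not taken from any library; a real implementation would build the mask as a tensor and usually combine it with a causal mask.

```python
PAD = 0  # segment id reserved for padding slots (an assumption of this sketch)


def segment_ids(lengths, bucket_size):
    """Assign id 1, 2, 3... to each packed request and PAD to the leftover slots."""
    if sum(lengths) > bucket_size:
        raise ValueError("requests do not fit in the bucket")
    ids = []
    for seg, n in enumerate(lengths, start=1):
        ids.extend([seg] * n)
    ids.extend([PAD] * (bucket_size - len(ids)))
    return ids


def segment_mask(ids):
    """mask[i][j] is 1 only when positions i and j belong to the same real request."""
    return [[int(a == b and a != PAD) for b in ids] for a in ids]


ids = segment_ids([3, 4], bucket_size=8)  # toy stand-in: two users plus one padding slot
mask = segment_mask(ids)
for row in mask:
    print(row)
```

The result is block-diagonal: each user sees only their own tokens, and the padding row is entirely zero. It follows the same `attention_mask == 0` convention as the `StaticShapeAttention` module above, so the zeros are exactly the cells that get overwritten with `-1e9`.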

It helps to be concrete about what "the rectangle" physically is. When you compile with `BATCH_SIZE = 4, STATIC_SEQ_LEN = 8192`, XLA reserves **one contiguous [4, 8192] static region** in the TPU's HBM: not four independent "rooms," but one big sheet the compiler hard-wires the array routes for. A single user rarely fills even one 8,192 lane, so the serving layer packs multiple users into the lanes:

```
[ One TPU processor: one static [4 x 8192] sheet ]

lane[0] (8192): [ A(2000) + B(5000) + C(1000) + pad(192) ]
lane[1] (8192): [ D(8000)                      + pad(192) ]
lane[2] (8192): [ E(3000) + F(3000) + G(2100)  + pad(92)  ]
lane[3] (8192): [ H(4000) + I(4000)            + pad(192) ]
```

Physically there are 4 lanes (32K of space); logically the proxy just crammed **9 ragged users (A–I)** into them. From the application side it looks like one TPU is concurrently servicing many small requests in parallel — but underneath it's one rigid sheet with a segment mask drawn over it. The reason the hardware wants one fat sheet instead of pre-carved small rooms is pure systolic-array physics: the bigger the matrix, the higher the array's fill rate and the fewer idle cycles between feeds.
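One simple way to arrive at a layout like the one above is first-fit packing in arrival order. The sketch below is an assumption about how such a proxy could work, not a description of any production packer, and `first_fit_pack` is a hypothetical helper.

```python
def first_fit_pack(requests, num_lanes, lane_len):
    """Place each (name, length) request into the first lane with enough free space.

    Returns (lanes, rejected): lanes is a list of lists of (name, length);
    rejected holds requests that fit nowhere and must wait for the next batch.
    """
    lanes = [[] for _ in range(num_lanes)]
    free = [lane_len] * num_lanes
    rejected = []
    for name, length in requests:
        for i in range(num_lanes):
            if length <= free[i]:
                lanes[i].append((name, length))
                free[i] -= length
                break
        else:
            # No lane had room for this request.
            rejected.append((name, length))
    return lanes, rejected


requests = [("A", 2000), ("B", 5000), ("C", 1000), ("D", 8000), ("E", 3000),
            ("F", 3000), ("G", 2100), ("H", 4000), ("I", 4000)]
lanes, rejected = first_fit_pack(requests, num_lanes=4, lane_len=8192)
used = sum(length for lane in lanes for _, length in lane)
for i, lane in enumerate(lanes):
    print(f"lane[{i}]:", lane)
print(f"padding: {4 * 8192 - used} of {4 * 8192} slots")
```

With these nine requests, first-fit reproduces the lane layout in the diagram and leaves 668 of 32,768 slots as padding, about 2%. A real packer also has to decide how long to wait for more requests before shipping a partly filled sheet, which this sketch ignores.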

Done right, padding shrinks to a few percent of the sheet and MFU (Model FLOPs Utilization) climbs toward the hardware's practical ceiling. But notice what you just built: a high-throughput Go/C++ proxy in front of the cluster whose only job is to catch ragged input and pack it into rectangles in real time. On NVIDIA, that entire layer **does not exist**.

People assume `torch_xla` abstracts the hardware away because `xm.xla_device()` transparently targets both TPU and Trainium (thanks to the shared OpenXLA/PJRT runtime: `libtpu.so` for TPU, `libneuronpjrt.so` for Neuron). That's true for `model.to(device)` and basic ops. It is emphatically *not* true for the parts that matter.

The `forward` signature itself diverges:

```
# NVIDIA forward: ragged data + boundary index. Length is arbitrary every call.
def forward(self, input_ids, cu_seqlens, max_seqlen):
    return self.flash_attn_func(input_ids, cu_seqlens, max_seqlen)

# Static forward: fixed rectangle + a mask matrix you must build yourself.
def forward(self, input_ids, attention_mask):  # input_ids is [batch, FixedSeqLen]
    return self.static_attn_func(input_ids, attention_mask)
```

And it cascades all the way down:

| Component | NVIDIA pipeline | Trainium pipeline |
|---|---|---|
| Inference engine | `vLLM` (CUDA), `TensorRT-LLM` | `NxD` / `vllm-neuron` |
| Custom kernels | Triton, CUDA C++ (`FlashAttention`) | NKI (Neuron Kernel Interface), rewritten from scratch |
| Base image | `nvcr.io/nvidia/pytorch` | AWS Neuron DLC |
| CI build artifact | weights + CUDA/Triton binaries | weights + NEFF static binaries per bucket |
| Deploy target | `g5` / `p5` instances | `trn1` / `inf2` instances |
| Monitoring | `nvidia-smi`, DCGM exporter | `neuron-top`, Neuron exporter |

Two completely parallel worlds. Your CUDA container, your eval scripts, your autoscaling triggers — none of it carries over. vLLM's hardware-plugin mechanism gives you "one skin" at the business-logic layer, but the engine underneath is 100% separate code with separate bugs.

The data-type story isn't symmetric either. BF16 (which Google's TPU pioneered) is stable on both sides: its FP32-range exponent survives the `-1e9` mask values without going NaN. But FP8, the current throughput play, favors NVIDIA: FP8 attention scores swing hard and need **dynamic scaling** at runtime to avoid clipping. A static compiler has to bake in a fixed scale factor at compile time, so on TPU/Trainium aggressive FP8 attention risks clipping that degrades model quality. "Just switch to FP8" is a one-liner on NVIDIA and a research project on static silicon.
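The clipping risk can be illustrated with a toy model. The sketch below only models the range limit of FP8 E4M3 (largest finite value 448) and ignores mantissa rounding; the calibration values and function names are invented for illustration and do not reflect how any particular compiler chooses its scales.

```python
FP8_E4M3_MAX = 448.0  # largest finite magnitude in the E4M3 format


def quantize_range_only(values, scale):
    """Scale into FP8 range and clamp. Returns (dequantized values, number clipped)."""
    out, clipped = [], 0
    for v in values:
        q = v * scale
        if abs(q) > FP8_E4M3_MAX:
            q = FP8_E4M3_MAX if q > 0 else -FP8_E4M3_MAX
            clipped += 1
        out.append(q / scale)
    return out, clipped


def scale_for(amax):
    """Scale that maps the largest expected magnitude onto the FP8 maximum."""
    return FP8_E4M3_MAX / amax


calibration_amax = 8.0                    # assumed peak observed at compile time
runtime_scores = [0.5, -3.0, 7.5, 32.0]   # one outlier the calibration never saw

# Static: the scale is frozen from calibration data.
static_out, static_clipped = quantize_range_only(runtime_scores, scale_for(calibration_amax))

# Dynamic: the scale is recomputed from the values actually present in this call.
dynamic_amax = max(abs(v) for v in runtime_scores)
dynamic_out, dynamic_clipped = quantize_range_only(runtime_scores, scale_for(dynamic_amax))

print("static :", static_out, "clipped:", static_clipped)
print("dynamic:", dynamic_out, "clipped:", dynamic_clipped)
```

In this toy run the frozen scale flattens the outlier 32.0 down to 8.0, while the per-call scale preserves it. Real FP8 kernels also lose precision to mantissa rounding, which this range-only model leaves out.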

This is the part that kills adoption and nobody puts on a slide. On NVIDIA there's a clean abstraction boundary:

```
[ AI engineer / data scientist ]
   architecture, hyperparams, eval
        │
        ▼  boundary: Hugging Face weights / standard PyTorch
        │
[ MLOps / LLMOps engineer ]
   drop into vLLM, configure PagedAttention, scale out
```

The data scientist never thinks about memory layout. The MLOps engineer never reads the attention math. They ship artifacts across a clean interface.

On TPU that wall **disappears**, because model structure is directly coupled to physical constraints:

- The packing logic (MLOps engineer) and the mask handling in `forward` (AI engineer) are two halves of one design. Change the batching strategy and the math has to change in lockstep. You cannot split that across a spec doc.
- Adding an `if` branch or changing layer count alters the compiled graph topology and triggers JIT stalls or OOM in production. Debugging that requires dumping the XLA HLO graph, which pulls the AI engineer into an "infra" incident.

The organizations that succeed on TPU (Google's Gemini team, Anthropic's Claude team, Meta's Llama-on-TPU group) abandoned the horizontal "data science dept / infra dept" split entirely. They run a single vertically-integrated team of people fluent in *both* the attention math and the compiler internals. Most companies cannot staff that, and the projects that try to keep the old division of labor die in a pile of compile errors and OOMs.

The whole calculus flips when **you control the input channel** so the shapes are predictable. Two clean examples:

[…] `neuronx-distributed`, with a Go/C++ proxy doing real-time packing), and Claude Code is, read cynically, the perfect input-locking channel that makes a Java-style chip worth the pain. Long-context workloads help too: a 200K-token prefill fills a 32K bucket with ~zero padding, so the static array's weakness evaporates exactly where Claude is strongest.

The inverse is just as logical, and it explains why the *chat* UIs stay on NVIDIA. ChatGPT and Claude.ai's web frontends accept arbitrary text, surprise image uploads, and topic switches mid-conversation. The system can't predict the shape until the user hits send. That chaos is precisely what dynamic SIMT + `PagedAttention` were built for.

The spec sheet was never lying about cost-per-token. It just wasn't pricing in the engineers, the forked pipeline, and the org redesign you have to buy first.
