# JAX: Commitment Issues

> Source: <https://www.gilesthomas.com/2026/06/jax-commitment-issues>
> Published: 2026-06-15 20:40:40+00:00

Imagine you have JAX code like this, and run it on a machine with CUDA set up:

```
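import jax

key = jax.random.key(42)

# Create the array on the CPU rather than on the default device (the GPU here).
cpu0 = jax.devices("cpu")[0]
with jax.default_device(cpu0):
    array = jax.random.randint(
        key,
        (530640, 6, 1024),
        0, 50_000,
        dtype=jax.numpy.uint16
    )
    array.block_until_ready()

# Fetch the first 6 x 1024 item and wait until it is ready.
item = array[0]
item.block_until_ready()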
```

We're creating a big array, blocking until it's ready (JAX is asynchronous, so this makes sure that it's actually finished creating it), then getting the first item, and as a belt-and-braces thing making sure that that is ready too. How long do you think those last two lines -- a simple retrieval of a 6 x 1024 array from a larger one -- will take? Some tiny fraction of a second would seem reasonable.

But running it on my machine just now, the answer is a bit of a surprise: just over 5 seconds. And if you try to
get `array[1]` immediately afterwards, it still takes about 1.2s. Further lookups into `array` consistently take
more than a second. So while the larger initial number might be something to do with setup (maybe internal stuff
being JITted), that's clearly not the whole story. Something is making these seemingly-simple array lookups take
much longer than you'd expect them to.
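If you want to measure this yourself, a minimal timing sketch might look like the following. It is not the script linked later in the post; `time_lookups` and `small_array` are names made up for illustration, and the array is deliberately far smaller than the one above so that it runs on modest hardware (so the timings will not match the numbers quoted here).

```python
import time

import jax
import jax.numpy as jnp


def time_lookups(array, count):
    """Return the wall-clock seconds taken by each of the first `count` lookups."""
    durations = []
    for i in range(count):
        start = time.perf_counter()
        item = array[i]
        item.block_until_ready()  # JAX is asynchronous, so wait for the result
        durations.append(time.perf_counter() - start)
    return durations


cpu0 = jax.devices("cpu")[0]
with jax.default_device(cpu0):
    # Assumption: a small stand-in for the post's (530640, 6, 1024) array.
    small_array = jnp.zeros((1000, 6, 1024), dtype=jnp.uint16)
    small_array.block_until_ready()

durations = time_lookups(small_array, 3)
for i, seconds in enumerate(durations):
    print(f"array[{i}] took {seconds:.4f}s")
```

Timing each lookup separately, rather than the loop as a whole, is what makes the difference between the first lookup and the later ones visible.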

Let's dig into that.

### A bit of background

First things first, why would you want to do that slightly strange dance with the
`jax.default_device` context manager in the first place, rather than telling `randint` what
device you want to use (e.g. with `out_sharding`)?

I'm writing some LLM training
code, and want to load my training dataset. I don't want to load it into the VRAM
on the GPU -- that would be a waste of valuable GPU resources -- so I need it in
the CPU-side memory. I'm using Safetensors, which will
[load stuff onto the system's default device](/2026/06/jax-backends-and-devices).
So I need to override that temporarily to make sure that the dataset is loaded onto
the device where I want it.

I initially discovered this problem when I tried to iterate over the resulting array in my training
loop; the code above is a simplified version of that -- a minimal repro of the issue.
And it's a serious one! If each iteration has an overhead of 1.2s just to get 6,144 tokens
ready for the model, JAX will max out at about 5,000 tokens per second
of training speed *just due to that overhead* -- a real forward and backward pass plus
an optimiser step will obviously make things even slower. For comparison, my PyTorch training loop managed almost 20,000 tokens/second
on the same hardware: all steps from getting the training data, putting it on the GPU,
and doing the actual training.
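The arithmetic behind that ceiling, as a quick sketch using only the figures quoted above (the variable names are made up for illustration, and the PyTorch number is the rounded "almost 20,000"):

```python
# Figures quoted in the post.
tokens_per_item = 6 * 1024          # one item is a 6 x 1024 block of tokens
lookup_overhead_seconds = 1.2       # per-item lookup time on the uncommitted array
pytorch_tokens_per_second = 20_000  # "almost 20,000" in the post, so slightly generous

# Upper bound on throughput if the lookup were the only cost in each step.
max_tokens_per_second = tokens_per_item / lookup_overhead_seconds
fraction_of_pytorch = max_tokens_per_second / pytorch_tokens_per_second

print(f"ceiling: about {max_tokens_per_second:,.0f} tokens/second")
print(f"that is about {fraction_of_pytorch:.0%} of the PyTorch figure")
```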

### Debugging

So, let's look at that code again. We've created
our variable `array` on the CPU explicitly, and indeed if you print `array.device`,
it says `CpuDevice(id=0)`. But if you print the device of the `item`, you get
`CudaDevice(id=0)`. What's worse, if you watch `nvtop` while the code is running,
as soon as it hits the lookup into the array, it starts using the GPU -- for each one,
there's a spike in GPU usage.

So, what gives? We asked JAX to put the array on the CPU, but now it's doing GPU work, and putting the items there.

The problem is that when you create an array using the `default_device` context manager,
it is placed on the specified device, but it's not [committed](https://docs.jax.dev/en/latest/_autosummary/jax.Array.committed.html) to it.
If an array is not committed to its device, then JAX will feel free to move it
around to others.

In order to commit an array to a device, you need to use `jax.device_put`, explicitly
stating which device you want it on. Running
the same code, but with this:

```
array = jax.device_put(array, cpu0)
```

...immediately before the lookup into the array changes the numbers drastically; the first lookup takes about 0.95s on my machine, the second 0.0002s, and then subsequent ones less than 0.0001s.

### Some more detailed tests

I decided to exercise this in depth, and wrote [this script](https://github.com/gpjt/jax-gpt2-from-scratch/blob/3d28bf64c87e92a62dceac26f32eaa2248ece12c/array_speed_test.py).
If you run it without the `--commit` command line flag, it will create the array,
then iterate over the first ten items, measuring how long it takes to get each one.
Running it just now:

- Getting the zeroth item from the array took about 5.4s.
- Each subsequent one consistently took about 1.2s.

With the `--commit` flag, it uses `device_put` to explicitly commit the array to the
CPU. Running that:

- Getting the zeroth item from the array took about 0.95s.
- Each subsequent one took less than 0.0002s.

Now, that didn't quite cover my use case -- what if, I wondered, the slow operation
was putting things onto the GPU? The script also has a `--put_items_to_gpu` flag to
do that -- after getting each item, it uses `device_put` to copy it to the GPU. With that flag:

- Getting the zeroth item from the array took about 0.86s, and putting it on the GPU took 0.02s.
- Subsequent items had "get" times similar to the previous run, and "put" times of about 0.0006s.

So, there's still a small startup penalty -- perhaps JAX is having to JIT some of its internal stuff -- but a perfectly decent speed after that. Commitment works!

### Wrapping up

I'm still building my mental model of how JAX works, and working out exactly what is going on here is proving a bit tricky. The split between a committed and an uncommitted array seems clear; the former is tied to a device, while JAX will move the latter around as needed.

It also makes a certain amount of sense that it would want to move the items to the GPU; it is, after all, the default device. But I'm less clear on why that was so slow, compared to the manual process of getting the item then putting it there.

Hypothesis: the array is in the CPU's RAM, but not committed there. We ask for an item from that array, and maybe JAX wants that to be on the default device, the GPU. So it moves the entire "parent" array there, extracts that item, and then returns it. Then next time around, when we ask for the next item, it does the same thing again.

Plausible? Maybe, but it does sound a bit pathological!

Anyway, at the end of the day, I have a solid new heuristic of my own: if you want something
to definitely be on some specific device, make sure that you nail it down there with
`device_put`. And then you won't have commitment issues like these.
