{"slug": "jax-commitment-issues", "title": "Jax: Commitment Issues", "summary": "JAX's default_device context manager places arrays on the specified device but does not commit them, allowing JAX to move them to other devices. This caused array lookups to take over a second by triggering GPU usage, but using jax.device_put to commit the array to the CPU reduced lookup times to milliseconds.", "body_md": "## JAX: commitment issues\n\nImagine you have JAX code like this, and you run it on a machine with CUDA set up:\n\n```\nimport jax\n\nkey = jax.random.key(42)\n\ncpu0 = jax.devices(\"cpu\")[0]\nwith jax.default_device(cpu0):\n    # Create the array while the CPU is the default device\n    array = jax.random.randint(\n        key,\n        (530640, 6, 1024),\n        0, 50_000,\n        dtype=jax.numpy.uint16\n    )\n    # JAX is asynchronous, so wait until the array actually exists\n    array.block_until_ready()\n\n# Back outside the context manager: look up a single 6 x 1024 item\nitem = array[0]\nitem.block_until_ready()\n```\n\nWe're creating a big array, blocking until it's ready (JAX is asynchronous, so this makes sure that it has actually finished creating it), then getting the first item and, as a belt-and-braces measure, making sure that that is ready too. How long do you think those last two lines -- a simple retrieval of a 6 x 1024 array from a larger one -- will take? Some tiny fraction of a second would seem reasonable.\n\nBut running it on my machine just now, the answer is a bit of a surprise: just over 5 seconds. And if you try to get `array[1]` immediately afterwards, it still takes about 1.2s. 
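\n\nTo reproduce timings like these, a small loop around `time.perf_counter` is enough, as long as it calls `block_until_ready` before stopping the clock. This is a minimal sketch, not the test script described later: `time_lookups` is a hypothetical helper, and the array is much smaller than the one above so that it runs quickly, so expect the absolute numbers to differ.

```python
import time

import jax
import jax.numpy as jnp


def time_lookups(array, count=3):
    # Hypothetical helper: wall-clock seconds for each of the first `count` lookups.
    timings = []
    for i in range(count):
        start = time.perf_counter()
        item = array[i]
        # JAX dispatches work asynchronously, so wait for the result before stopping the clock.
        item.block_until_ready()
        timings.append(time.perf_counter() - start)
    return timings


key = jax.random.key(42)
cpu0 = jax.devices('cpu')[0]
with jax.default_device(cpu0):
    # Far smaller than the (530640, 6, 1024) array above, so the sketch runs quickly.
    array = jax.random.randint(key, (1000, 6, 1024), 0, 50_000, dtype=jnp.uint16)
    array.block_until_ready()

for i, seconds in enumerate(time_lookups(array)):
    print(f'array[{i}] took {seconds:.6f}s')
```

Timing each lookup separately, rather than the loop as a whole, is what separates a one-off startup cost on the first item from a cost paid on every item.\n\n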
Further lookups into `array` consistently take more than a second, so while the larger initial number might be down to setup (maybe internal stuff being JITted), that's clearly not the whole story. Something is making these seemingly simple array lookups take much longer than you'd expect.\n\nLet's dig into that.\n\n### A bit of background\n\nFirst things first: why would you want to do that slightly strange dance with the `jax.default_device` context manager in the first place, rather than telling `randint` what device you want to use (e.g. with `out_sharding`)?\n\nI'm writing some LLM training code, and want to load my training dataset. I don't want to load it into the VRAM on the GPU -- that would be a waste of valuable GPU resources -- so I need it in CPU-side memory. I'm using Safetensors, which will [load stuff onto the system's default device](https://www.gilesthomas.com/2026/06/jax-backends-and-devices), so I need to override that temporarily to make sure that the dataset is loaded onto the device where I want it.\n\nI initially discovered this problem when I tried to iterate over the resulting array in my training loop; the code above is a simplified version of that -- a minimal repro of the issue. And it's a serious one! If each iteration has an overhead of 1.2s just to get 6,144 tokens (6 x 1,024) ready for the model, JAX will max out at about 5,000 tokens per second of training speed (6,144 / 1.2 is roughly 5,100) *from that overhead alone* -- a real forward and backward pass plus an optimiser step will obviously make things even slower. For comparison, my PyTorch training loop managed almost 20,000 tokens/second on the same hardware, covering every step: getting the training data, putting it on the GPU, and doing the actual training.\n\n### Debugging\n\nSo, let's look at that code again. We created our variable `array` on the CPU explicitly, and indeed if you print `array.device`, it says `CpuDevice(id=0)`. 
But if you print the device of the `item`, you get `CudaDevice(id=0)`. What's worse, if you watch `nvtop` while the code is running, you'll see that as soon as it hits the lookup into the array it starts using the GPU: there's a spike in GPU usage for every lookup.\n\nSo, what gives? We asked JAX to put the array on the CPU, but now it's doing GPU work and putting the items there.\n\nThe problem is that when you create an array using the `default_device` context manager, it is placed on the specified device, but it's not [committed](https://docs.jax.dev/en/latest/_autosummary/jax.Array.committed.html) to it. If an array is not committed to its device, then JAX feels free to move it around to others.\n\nTo commit an array to a device, you need to use `jax.device_put`, explicitly stating which device you want it on. Running the same code, but with this:\n\n```\n# Commit the array to the CPU so that JAX won't move it elsewhere\narray = jax.device_put(array, cpu0)\n```\n\n...immediately before the lookup into the array changes the numbers drastically: the first lookup takes about 0.95s on my machine, the second 0.0002s, and subsequent ones less than 0.0001s.\n\n### Some more detailed tests\n\nI decided to exercise this in more depth, and wrote [this script](https://github.com/gpjt/jax-gpt2-from-scratch/blob/3d28bf64c87e92a62dceac26f32eaa2248ece12c/array_speed_test.py). If you run it without the `--commit` command-line flag, it will create the array, then iterate over the first ten items, measuring how long it takes to get each one. Running it just now:\n\n- Getting the zeroth item from the array took about 5.4s.\n- Each subsequent one consistently took about 1.2s.\n\nWith the `--commit` flag, it uses `device_put` to explicitly commit the array to the CPU. Running that:\n\n- Getting the zeroth item from the array took about 0.95s.\n- Each subsequent one took less than 0.0002s.\n\nNow, that didn't quite cover my use case -- what if, I wondered, the slow operation was putting things onto the GPU? 
The script also has a `--put_items_to_gpu` flag to test that: after getting each item, it uses `device_put` to move the item onto the GPU. With that flag:\n\n- Getting the zeroth item from the array took about 0.86s, and putting it on the GPU took 0.02s.\n- Subsequent items had \"get\" times similar to the previous run, and \"put\" times of about 0.0006s.\n\nSo there's still a small startup penalty -- perhaps JAX is having to JIT some of its internal stuff -- but a perfectly decent speed after that. Commitment works!\n\n### Wrapping up\n\nI'm still building my mental model of how JAX works, and working out exactly what is going on here is proving a bit tricky. The split between committed and uncommitted arrays seems clear: the former is tied to a device, while JAX will move the latter around as needed.\n\nIt also makes a certain amount of sense that JAX would want to move the items to the GPU; it is, after all, the default device. But I'm less clear on why that was so slow compared to the manual process of getting the item and then putting it there.\n\nHypothesis: the array is in the CPU's RAM, but not committed there. We ask for an item from that array, and maybe JAX wants that lookup to happen on the default device, the GPU. So it moves the entire \"parent\" array there, extracts the item, and returns that. Then, next time around, when we ask for the next item, it does the same thing again. That would at least be consistent with the cost: the full array holds 530,640 x 6 x 1,024 `uint16` values, roughly 6.5GB, so copying it for every lookup would not be cheap.\n\nPlausible? Maybe, but it does sound a bit pathological!\n\nAnyway, at the end of the day, I have a solid new heuristic of my own: if you want something to definitely be on some specific device, make sure that you nail it down there with `device_put`. 
And then you won't have commitment issues like these.", "url": "https://wpnews.pro/news/jax-commitment-issues", "canonical_source": "https://www.gilesthomas.com/2026/06/jax-commitment-issues", "published_at": "2026-06-15 20:40:40+00:00", "updated_at": "2026-06-15 21:04:35.268099+00:00", "lang": "en", "topics": ["machine-learning", "ai-tools"], "entities": ["JAX", "CUDA", "GPU", "CPU", "Safetensors", "PyTorch"], "alternates": {"html": "https://wpnews.pro/news/jax-commitment-issues", "markdown": "https://wpnews.pro/news/jax-commitment-issues.md", "text": "https://wpnews.pro/news/jax-commitment-issues.txt", "jsonld": "https://wpnews.pro/news/jax-commitment-issues.jsonld"}}