{"slug": "jax-backends-and-devices", "title": "JAX backends and devices", "summary": "JAX defaults to loading data directly onto GPU memory when a CUDA-enabled version is installed, causing out-of-memory errors for large datasets that would fit in system RAM. The framework's `jax.devices()` function only returns devices from the default backend, which prioritizes GPU over CPU, requiring users to explicitly specify `jax.devices(\"cpu\")` to access CPU resources. This design choice means developers must manually manage device placement for data loading, unlike PyTorch's default behavior of loading into RAM.", "body_md": "There's nothing like writing your own code with a framework to clarify how things\nfit together! Continuing with porting my [PyTorch LLM code](https://www.gilesthomas.com/llm-from-scratch) to\n[JAX](https://docs.jax.dev/), I wanted to load up a large dataset:\nthe 10,248,871,837 16-bit unsigned integers in the `train` split of\n[gpjt/fineweb-gpt2-tokens](https://huggingface.co/datasets/gpjt/fineweb-gpt2-tokens).\nThat's just over 19 GiB of data.\n\n``` python\nfrom safetensors.flax import load_file\n...\nfull_dataset = load_file(dataset_dir / \"train.safetensors\")[\"tokens\"]\n```\n\nWhen I ran that, I got a CUDA out-of-memory error:\n\n```\njax.errors.JaxRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 19.09GiB.\n```\n\nThat makes sense! The allocation it was trying to make is exactly the size of the data I was loading. I have an RTX 3090 with 24 GiB, but some of that is already used by the OS, various apps, and a model that the code creates earlier on.\n\nBut in PyTorch-land I was used to things being loaded into RAM by default, and only moved over to the GPU when I asked for that. JAX was clearly loading onto the GPU by default. How could I stop it from doing that in this case? 
The load onto the GPU was happening inside Safetensors, in code I couldn't directly control.\n\nWorking out how to get around that helped me understand a little more about JAX.\n\nJAX has a function that looks relevant: `jax.devices`.\nWithout reading the docs, let's try running it. In my virtualenv, with the `jax[cuda13]` package installed, I get this:\n\n``` python\nIn [1]: import jax\n\nIn [2]: all_devices = jax.devices()\n\nIn [3]: all_devices\nOut[3]: [CudaDevice(id=0)]\n```\n\nThat seems a bit weird! I do indeed have a CUDA device, but I obviously also have a CPU. Why isn't it showing up?\n\nRunning the same code in another virtualenv, with just `jax` installed -- no CUDA -- gets this:\n\n``` python\nIn [1]: import jax\n\nIn [2]: all_devices = jax.devices()\nAn NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.\n\nIn [3]: all_devices\nOut[3]: [CpuDevice(id=0)]\n```\n\nOK, so it *did* recognise the CPU this time. Feels like it might be time to RTFM.\n\nThe [jax.devices docs](https://docs.jax.dev/en/latest/_autosummary/jax.devices.html) explain things\na bit:\n\n> `jax.devices(backend=None)`\n>\n> Returns a list of all devices for a given backend.\n>\n> ...\n>\n> If `backend` is `None`, returns all the devices from the default backend. The default backend is generally `'gpu'` or `'tpu'` if available, otherwise `'cpu'`.\n\nOK. So JAX has multiple backends -- named that because they're the classes of hardware that XLA (the compiler behind JAX's JIT) targets. There is a default one, which is essentially going to be the \"best\" one available given the hardware configuration and the parts of JAX that are installed.\n\nWhen I had the CUDA version installed, it made the `gpu` backend the default, but when I didn't, it defaulted to `cpu` (and warned me).\n\n
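The following sketch shows the default-backend behaviour directly. It uses `jax.default_backend()`, which isn't mentioned above but is part of JAX's public API and returns the name of the backend JAX chose. `describe_devices` is just a hypothetical helper name for this example.

```python
import jax


def describe_devices():
    # Hypothetical helper: collect what JAX reports about backends and devices.
    return {
        # Name of the backend JAX picked as its default, e.g. 'gpu' or 'cpu'.
        'default_backend': jax.default_backend(),
        # No argument: only the devices on that default backend are listed.
        'default_devices': jax.devices(),
        # Naming a backend explicitly lists that backend's devices instead.
        'cpu_devices': jax.devices('cpu'),
    }


if __name__ == '__main__':
    for key, value in describe_devices().items():
        print(key, value)
```

With the CUDA-enabled `jaxlib`, the first two entries should report `gpu` and the CUDA device, while the third still lists the CPU. On a CPU-only install, all three point at the CPU.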
Because `jax.devices()` only shows the devices on the default backend, when that was `gpu` I didn't see the CPU.\n\nHowever, you can specify which backend you want with that `backend` parameter, so let's go back to the\nvirtualenv with CUDA:\n\n``` python\nIn [4]: jax.devices(\"cpu\")\nOut[4]: [CpuDevice(id=0)]\n```\n\nGreat! So is there some way to list which backends are available?\n[Apparently not](https://github.com/jax-ml/jax/issues/6459) -- the recommended\nway appears to be to try getting the devices for each possible backend, and\ncatch the `RuntimeError`s to see which ones aren't available. Yuck.\n\nBut maybe that's not such a big deal. In PyTorch-land I was very much used to putting code like this near the start of my scripts:\n\n``` python\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n```\n\n...then moving models to the device:\n\n``` python\nmodel.to(device)\n```\n\n...and then moving data to the model's device as needed:\n\n``` python\ndevice = next(model.parameters()).device\n...\ninputs = inputs.to(device)\n```\n\nWhat I actually wanted was essentially what JAX does -- have everything on the fastest device available at all times -- but with specific exceptions. In particular, the one that started off this investigation: how would I put this huge array of training data in the CPU's RAM rather than the GPU's VRAM?\n\nI had a bit of a false start when I spotted that the `load_file` function in the\n[Safetensors Flax API](https://huggingface.co/docs/safetensors/api/flax) has a `backend` parameter, but that appears to be more to do with how it loads up the file -- a backend\nin a different sense.\n\n
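The probing workaround from that GitHub issue could look something like this minimal sketch. The candidate list and the `available_backends` helper are hypothetical, chosen for illustration. The code relies on `jax.devices` raising `RuntimeError` for a backend that isn't available, as described above.

```python
import jax

# Hypothetical list of backend names to probe; extend it if you need others.
CANDIDATE_BACKENDS = ('cpu', 'gpu', 'tpu')


def available_backends(candidates=CANDIDATE_BACKENDS):
    # Hypothetical helper: map each backend that initialises successfully
    # to its list of devices.
    found = {}
    for name in candidates:
        try:
            found[name] = jax.devices(name)
        except RuntimeError:
            # Backend missing or failed to initialise: skip it.
            continue
    return found


if __name__ == '__main__':
    for name, devices in available_backends().items():
        print(name, devices)
```

On the CUDA virtualenv this should find both `cpu` and `gpu`. On the CPU-only one, it should find just `cpu`.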
In any case, *backend* is not the right concept in JAX-land, as\na backend is just something generic like `gpu` -- for what we're trying to do, we want\nto load the data onto a specific *device*.\n\nAfter some digging around, I discovered that JAX has a concept of a *default device*,\nwhich is the one it uses when it doesn't have any\nindication of where to put something. It makes sense that this will be on the default\nbackend -- indeed, it looks like it's essentially \"the first device in the list that\n`jax.devices` returns for the default backend\".\n\nThere is a `jax_default_device` config option which\nyou can use to set it; you'd [normally use](https://docs.jax.dev/en/latest/config_options.html) `jax.config.update` or an environment\nvariable to change it.\n\nBut what if you only want to change it temporarily? I found [this documentation for\njax.default_device](https://docs.jax.dev/en/latest/_autosummary/jax.default_device.html).\n\nThe docs are more than a little confusing:\n\n> `jax.default_device = <jax._src.config.State object>`\n>\n> Context manager for `jax_default_device` config option.\n>\n> Configure the default device for JAX operations. Set to a Device object (e.g. `jax.devices(\"cpu\")[0]`) to use that Device as the default device for JAX operations and jit’d function calls (there is no effect on multi-device computations, e.g. pmapped function calls). 
Set to None to use the system default device.\n\nThat `=` near the start tripped me up, as I missed the words \"Context manager\" just\nbelow, and the odd `State` type, and tried this:\n\n``` python\n# Doesn't work: this rebinds the module attribute rather than changing the config\njax.default_device = jax.devices(\"cpu\")[0]\nfull_dataset = load_file(dataset_dir / \"train.safetensors\")[\"tokens\"]\njax.default_device = None\n```\n\nI still got the CUDA OOM, though, so I reread the docs, spotted the \"context manager\" bit, swore violently, and tried this:\n\n``` python\nwith jax.default_device(jax.devices(\"cpu\")[0]):\n    full_dataset = load_file(dataset_dir / \"train.safetensors\")[\"tokens\"]\n```\n\n...which works. It looks like the equals sign in the docs is being used to mean something\nvery different from its usual meaning, and they decided not to actually\ndocument the signature of the context manager. Heigh ho. I guess\n[documentation is hard](https://www.reddit.com/r/ProgrammerHumor/comments/gv6yhy/documentation_hell/).\n\nStill, at least now I have a solution. And as I said earlier, doc grumbles aside, the shape of the code might wind up being a little less fiddly than in PyTorch. The default location for things I create is the fastest hardware I have, which is what I want. And for the rare exceptions when I don't want that, there is a reasonably simple (now that I know it) way to say where I want things to go.\n\nI'll call that a win :-) The only thing I'll need to remember is that when, in\nmy training loop, I want to use subsets of that in-RAM tensor, I'll need to move them\nto the GPU.\n\n
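Putting those pieces together, here's a minimal sketch of the load-in-RAM, move-slices-to-GPU pattern. To keep it self-contained, a small `jnp.arange` array of `uint16` values stands in for the Safetensors load. `load_tokens_on_cpu` and `batch_to_device` are hypothetical helper names. The sketch allocates the array under `jax.default_device` as above, then uses `jax.device_put` to copy one slice at a time onto the default device (the GPU, in the CUDA setup).

```python
import jax
import jax.numpy as jnp


def load_tokens_on_cpu(num_tokens):
    # Hypothetical stand-in for the load_file(...) call: inside the with block
    # the CPU is the default device, so the array is allocated in system RAM.
    cpu = jax.devices('cpu')[0]
    with jax.default_device(cpu):
        tokens = jnp.arange(num_tokens, dtype=jnp.uint16)
    # Commit the array to the CPU so that later operations on it, like the
    # slicing below, run there rather than on the default device.
    return jax.device_put(tokens, cpu)


def batch_to_device(tokens, start, length, device=None):
    # Hypothetical helper: slice on the CPU, then copy only that slice to the
    # target device; by default, the first device on the default backend.
    if device is None:
        device = jax.devices()[0]
    return jax.device_put(tokens[start:start + length], device)


if __name__ == '__main__':
    tokens = load_tokens_on_cpu(10_000)
    batch = batch_to_device(tokens, start=1_024, length=512)
    print(tokens.devices(), batch.devices(), batch.shape, batch.dtype)
```

The explicit `device_put` onto the CPU reflects JAX's documented placement rules. Arrays created without an explicit device are *uncommitted*, and computations on uncommitted data are performed on the default device. Committing the array should therefore keep the slicing in RAM. It's worth checking whether that extra call copies an array that's already on the CPU before trying it on the full 19 GiB.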
[jax.device_put](https://docs.jax.dev/en/latest/_autosummary/jax.device_put.html) looks\nlike the right tool for moving those subsets across.", "url": "https://wpnews.pro/news/jax-backends-and-devices", "canonical_source": "https://www.gilesthomas.com/2026/06/jax-backends-and-devices", "published_at": "2026-06-05 19:30:00+00:00", "updated_at": "2026-06-05 19:07:05.612266+00:00", "lang": "en", "topics": ["machine-learning", "large-language-models", "ai-tools", "ai-infrastructure"], "entities": ["JAX", "PyTorch", "CUDA", "Safetensors", "RTX 3090", "Hugging Face"], "alternates": {"html": "https://wpnews.pro/news/jax-backends-and-devices", "markdown": "https://wpnews.pro/news/jax-backends-and-devices.md", "text": "https://wpnews.pro/news/jax-backends-and-devices.txt", "jsonld": "https://wpnews.pro/news/jax-backends-and-devices.jsonld"}}