{"slug": "using-safetensors-with-flax", "title": "Using Safetensors with Flax", "summary": "A developer porting PyTorch LLM code to JAX using Flax encountered difficulties when attempting to store model checkpoints with Safetensors, as the library's Flax API expects flat dictionaries but Flax's `nnx.State.to_pure_dict` produces nested structures. The mismatch caused a confusing \"dtype object is not covered\" error because Safetensors silently converted nested dicts into NumPy object arrays instead of validating input types. The issue highlights a documentation gap in Safetensors' Flax integration that can trip up users migrating from PyTorch to JAX-based workflows.", "body_md": "I'm porting my [PyTorch LLM code](https://www.gilesthomas.com/llm-from-scratch) to\n[JAX](https://www.gilesthomas.com/2026/05/on-first-looking-into-jax), using\n[Flax](https://flax.readthedocs.io/en/stable/) as the neural network layer.\nFor various reasons I wanted to use [Safetensors](https://huggingface.co/docs/safetensors/index)\nto store checkpoints of the model. It took a little while to get it working;\nhere's the trick I learned.\n\nIf you look at the Safetensors docs, you'll see that it doesn't mention a JAX implementation --\nindeed, searching for \"safetensors jax\" at the time I'm writing this gives you a link\nto [this GitHub repo by Alvaro Bartolome](https://github.com/alvarobartt/safejax) -- which was last updated in\n2023.\n\nHowever, if you look more closely at the docs, they *do* have a link to the\n[Flax API](https://huggingface.co/docs/safetensors/api/flax). I feel this is somewhat\nmisnamed, as it is actually a JAX API. There's no reference (again, as of the time of\nwriting) to Flax in the source -- it's all just JAX code. And in fact Bartolome's library\nuses it under the hood.\n\nThere is one problem, though. The API works with simple single-level dictionaries,\nwith strings mapping directly to JAX arrays. 
For example, the `save_file` function has this signature:\n\n``` python\ndef save_file(\n    tensors: Dict[str, Array],\n    filename: Union[str, os.PathLike],\n    metadata: Optional[Dict[str, str]] = None,\n) -> None\n```\n\nThis can cause problems if you're not careful. If you look at the\n[Flax documentation on checkpointing](https://flax.readthedocs.io/en/stable/guides/checkpointing.html),\nit suggests that you use [Orbax](https://orbax.readthedocs.io/en/latest/index.html)[^1],\nwhich has its own API and file format, but then goes on to say:\n\n> When interacting with checkpoint libraries (like Orbax), you may prefer to work with Python built-in container types. In this case, you can use the `nnx.State.to_pure_dict` and `nnx.State.replace_by_pure_dict` API to convert an `nnx.State` to and from pure nested dictionaries.\n\nI initially put two and two together -- that and the dictionary-based API for Safetensors -- and got five, and tried feeding one of those \"pure\" dicts into Safetensors. I got a very confusing error:\n\n```\nSafetensorError: dtype object is not covered\n```\n\nIt's worth digging into why that happens.\n\nThe problem is that although Safetensors is expecting a dict of strings mapping to\ntensors, it doesn't check that that is what it actually gets. And while the dictionaries\nfrom `nnx.State.to_pure_dict` are \"pure\", they are also nested (as the docs say!). Even for the simple\nmodel I was working with, I got a structure like this:\n\n```\n{\n    'output_head': {\n        'kernel': Array([...], dtype=float32)\n    },\n    'token_embedding': {\n        'embedding': Array([...], dtype=float32)\n    }\n}\n```\n\nSo, we had strings mapping to dicts, and those dicts mapped from strings to the JAX arrays. More complex models would have had deeper dict structures.\n\nNow, inside Safetensors, the Flax/JAX API is a simple wrapper. 
It\niterates over the keys in the dictionary it's been provided with, and tries to convert\ntheir respective values into NumPy arrays. It does that by passing them into\nNumPy's `asarray` function, which accepts things like lists, tuples, and NumPy arrays,\nand converts them into arrays. JAX's own `Array` class exposes an interface that `asarray`\nrecognises, so JAX arrays are converted without trouble.\n\nOnce it's done that, it passes the result to a lower-level Rust implementation that actually converts everything to Safetensors format.\n\nBut because Safetensors didn't check types, in my case it was iterating over the top level of the dict, trying to convert the values to NumPy arrays, and producing something like this:\n\n```\n{\n    'output_head': numpy.array({'kernel': Array([...], dtype=float32)}, dtype=object),\n    'token_embedding': numpy.array({'embedding': Array([...], dtype=float32)}, dtype=object)\n}\n```\n\nThat is -- because it assumed that the values in the top-level dict were JAX `Array`s,\nit blindly tried to convert them to NumPy arrays. But they were dicts (that\nhappened to map from strings to arrays) -- and if you ask `asarray` to create an array based on\na random object, it happily does so and wraps that object in a NumPy array, with a `dtype` of `object`.\n\nWhen that is then fed into the lower-level Rust code that is trying to write the\nfile, it encounters NumPy arrays that have a `dtype` it can't handle, `object` --\nhence that error:\n\n```\nSafetensorError: dtype object is not covered\n```\n\nIt all makes sense when you read through the code, but I was a bit perplexed for a while!\n\nI think all this might be the reason why Bartolome created his GitHub repo. 
In the README, he says that:\n\n> There are no plans from HuggingFace to extend safetensors to support anything more than tensors e.g. `FrozenDicts`, see their response at [huggingface/safetensors/discussions/138]. So the motivation to create `safejax` is to easily provide a way to serialize `FrozenDicts` using safetensors as the tensor storage format\n\nHowever, you don't need to use that library to serialise simple Flax models.\n\nConsider how PyTorch models get serialised to Safetensors; my LLMs have keys with names\nlike `out_head.weight`, `pos_emb.weight`, and `trf_blocks.0.att.out_proj.weight`.\nThey're \"flat\" dictionaries mapping strings to PyTorch Tensors, similar to what Safetensors\nwants for these Flax ones, but they use dots to separate different levels, with integers\nfor list items and strings for field names.\n\nLooking at the pure-dict structure I had for my model:\n\n```\n{\n    'output_head': {\n        'kernel': Array([...], dtype=float32)\n    },\n    'token_embedding': {\n        'embedding': Array([...], dtype=float32)\n    }\n}\n```\n\n...you can see that you could walk the dictionary structure to generate keys like\n`output_head.kernel` and `token_embedding.embedding`. 
That would be easy enough to code\nup.\n\nBut -- as Adithya Dsilva [points out on GitHub](https://github.com/google/flax/discussions/4900) -- you can get there even faster by using\n[`nnx.to_flat_state`](https://flax.readthedocs.io/en/stable/api_reference/flax.nnx/state.html#flax.nnx.to_flat_state).\nThat returns a (non-dict) structure like this:\n\n```\nFlatState([\n  (('output_head', 'kernel'), Param( # 786,432 (3.1 MB)\n    value=Array([[ 2.3581974e-02,  3.0957451e-02, -3.5088759e-02, ...,\n            -4.5880198e-02,  5.3717274e-02, -2.6590331e-02],\n           ...,\n           [-9.6302675e-03, -3.3276502e-02,  5.7173111e-02, ...,\n            -7.9063717e-03,  2.0532632e-02,  5.4753982e-02]], dtype=float32)\n  )),\n  (('token_embedding', 'embedding'), Param( # 786,432 (3.1 MB)\n    value=Array([[ 0.00273973, -0.01754938,  0.04656043, ..., -0.04276522,\n            -0.03986642, -0.00781331],\n           ...,\n           [ 0.01421758, -0.0219186 , -0.01701825, ..., -0.00793659,\n             0.00500103,  0.03839901]], dtype=float32)\n  ))\n])\n```\n\nIf you iterate over that `FlatState`, you get tuples where the first element is that\ntuple of strings, like `('output_head', 'kernel')`, and the second is a `Param` object wrapping\nthe JAX `Array`.\nThe tuples mirror the dot-separated string format in the PyTorch-style Safetensors files.\n\n`Param` objects also implement an interface that `asarray` can understand,\nso you can quickly and easily convert the `FlatState` to a regular dict for Safetensors:\n\n``` python\n    from flax import nnx\n    from safetensors.flax import save_file\n\n    ...\n\n    model_state = nnx.state(model)\n    flat_state = nnx.to_flat_state(model_state)\n    simple_dict = {}\n    for tuple_key, param in flat_state:\n        # Join the path into a dotted name; str() handles integer indices.\n        key = \".\".join(str(part) for part in tuple_key)\n        simple_dict[key] = param\n\n    save_file(simple_dict, \"model.safetensors\")\n```\n\n(You need to wrap each path element in a `str` because if you have a 
`nnx.Sequential` in your model, the\nitem in the tuple will get an integer index rather than a string).\n\nYou can go the other way pretty easily too; given a model, you can load the saved\ncheckpoint into it like this (because `from_flat_state` accepts raw JAX `Array`s\nin place of explicit `Param`s):\n\n``` python\n    from flax import nnx\n    from safetensors.flax import load_file\n\n    ...\n\n    simple_dict = load_file(\"model.safetensors\")\n\n    dict_flat_state = {}\n    for key, array in simple_dict.items():\n        # Rebuild the tuple path, turning numeric parts back into integer indices.\n        elements = key.split(\".\")\n        list_key = []\n        for element in elements:\n            try:\n                list_key.append(int(element))\n            except ValueError:\n                list_key.append(element)\n        dict_flat_state[tuple(list_key)] = array\n\n    new_flat_state = nnx.from_flat_state(dict_flat_state)\n    nnx.update(model, new_flat_state)\n```\n\nA little more work than I'd ideally like, but given that it can be tucked away\nin general `save_checkpoint`/`load_checkpoint` functions, not too big a deal.\n\nHope that's of use for other people coming across this problem!\n\n[^1]: I'm beginning to feel a bit swamped with all of these libraries with names\nending in -ax. It reminds me of\n[the names of the characters in Asterix's village](https://asterix.fandom.com/wiki/List_of_Asterix_characters#Villagers_of_the_Indomitable_Village)... 
", "url": "https://wpnews.pro/news/using-safetensors-with-flax", "canonical_source": "https://www.gilesthomas.com/2026/06/flax-and-safetensors", "published_at": "2026-06-04 23:30:00+00:00", "updated_at": "2026-06-04 23:07:06.941473+00:00", "lang": "en", "topics": ["machine-learning", "neural-networks", "ai-tools", "ai-infrastructure", "large-language-models"], "entities": ["PyTorch", "JAX", "Flax", "Safetensors", "Hugging Face", "Alvaro Bartolome", "safejax", "GitHub"], "alternates": {"html": "https://wpnews.pro/news/using-safetensors-with-flax", "markdown": "https://wpnews.pro/news/using-safetensors-with-flax.md", "text": "https://wpnews.pro/news/using-safetensors-with-flax.txt", "jsonld": "https://wpnews.pro/news/using-safetensors-with-flax.jsonld"}}