{"slug": "flax-debugging-making-a-hash-of-things", "title": "Flax debugging: making a hash of things", "summary": "A developer debugging a JAX/Flax NNX training loop discovered that the loss was stuck at 10.82, indicating the model was performing no better than random guessing. The issue was traced to the training loop's plumbing rather than the model or loss function, as gradients were not being properly applied to the model parameters.", "body_md": "## Flax debugging: making a hash of things\n\nI was debugging an issue with a JAX/Flax NNX training loop the other day, and found a neat little trick to help debug it. Specifically, I wanted to see if the issue was with my model, my loss function, my optimiser settings, or the \"plumbing\" of the training loop itself -- were gradients actually coming through and being applied to the parameters?\n\nI could print out the loss and the gradients, but printing out the parameters to see if they were changing was unhelpful -- any given update might only change a small number of parameters, or might change them by such a small amount that I'd not notice -- especially given that the model had 77 million of them!\n\nLet's take a look.\n\n### The world's worst LLM\n\nI am building an LLM from scratch in JAX and Flax NNX, and at this stage I'm trying to get the training loop right. As a simple test, I've just implemented the \"shell\" of the LLM -- the token embeddings on the input side, and the final linear layer for an output head, wired directly together. 
My plan was to train that so that given a sequence, instead of predicting next tokens for each position, it would \"predict\" the sequence itself -- that is, I might train it with the input\n\n```\nThe fat cat sat on the mat\n```\n\n...and the target\n\n```\nThe fat cat sat on the mat\n```\n\n...rather than the normal setup for an LLM, where you feed it\n\n```\nThe fat cat sat on the\n```\n\n...and give it targets of\n\n```\nfat cat sat on the mat\n```\n\nSo, in LLM terms, I'd be training a model to project from vocab space to a learned embedding\nspace where each token had a distinct-enough embedding for the output head to be able\nto reliably project back to logits in vocab space. There's\n[a bit of background here if that was all Greek to you](/2025/05/llm-from-scratch-15-from-context-vectors-to-logits).\n\nHere's the core part of the code I was working with, the `train_step`\nfunction, which\nseems to be the traditional JAX name for the JITted part of your code that does the\nforward pass through the model, works out the gradients, and then applies them to update\nthe model:\n\n``` python\n@jax.jit\ndef train_step(model, optimizer, inputs, targets):\n    loss, grads = nnx.value_and_grad(calculate_loss)(model, inputs, targets)\n    optimizer.update(model, grads)\n    return loss\n```\n\nI'd based it on the [\"Basic Usage\" example](https://flax.readthedocs.io/en/stable/#basic-usage) that's\ncurrently right there on the front page of the Flax site. Seasoned Flax veterans will probably\nspot the issue right away, but it wasn't obvious to me -- so it was time to dig in.\n\n### Dealing with loss\n\nThe problem was that the loss was not dropping -- indeed, taken to two decimal places, it was stuck at 10.82. The digits\nto the right of that changed for each batch, but the first four digits did not. 
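
It helps to know what random guessing looks like numerically. If a model assigns equal probability to each of V tokens, the cross-entropy loss at every position is -ln(1/V) = ln V, and perplexity is e raised to the power of the loss. The sketch below checks both for the GPT-2 vocab size; `cross_entropy` is a hypothetical stand-in, since the post's own `calculate_loss` isn't shown.

```python
import math

import jax
import jax.numpy as jnp


def cross_entropy(logits, targets):
    # Hypothetical stand-in for a loss function: mean negative log-probability
    # assigned to the target token at each position
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    target_log_probs = jnp.take_along_axis(log_probs, targets[..., None], axis=-1)
    return -jnp.mean(target_log_probs)


VOCAB_SIZE = 50257  # GPT-2 tokeniser

# All-zero logits give a uniform distribution over the vocab: pure guessing
uniform_logits = jnp.zeros((2, 8, VOCAB_SIZE))
targets = jnp.zeros((2, 8), dtype=jnp.int32)

baseline_loss = float(cross_entropy(uniform_logits, targets))
print(baseline_loss, math.log(VOCAB_SIZE))  # both about 10.825
print(math.exp(10.82))  # about 50,011: the perplexity for a loss of 10.82
```
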
Now, this model was\nusing the GPT-2 tokeniser, and 10.82 is exactly the loss that you'd expect if the model\nwas essentially guessing randomly: a uniform guess over all 50,257 tokens gives a cross-entropy loss of ln(50,257) ≈ 10.82. Going the other way, if you convert the loss to\n[perplexity](/2025/10/llm-from-scratch-21-perplexed-by-perplexity) by\ncalculating $e^{10.82}$, you get about 50,011 -- which is very close to the GPT-2\nvocab size of 50,257. Perplexity is, loosely, the number of tokens that the model\nwas trying to choose between for a typical input -- so a perplexity equal to the vocab size\nis what you'd expect of a random model that is getting it right about one in 50,257\ntimes.\n\nThat said, getting that loss consistently was a solid validation of my loss function! It's vanishingly unlikely that it would have been getting that specific number so consistently if I'd made a mess of that. The tiny variations I was seeing in the third and subsequent decimal places would make sense, as they could easily be due to the variations in the contents of the different batches.\n\n### Gradient descent into madness\n\nSo was it that the gradients were somehow zero, or NaNs, or something else that couldn't\nbe usefully applied to the model by the optimiser? 
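
Rather than eyeballing printed arrays, those failure modes can be checked directly: a global gradient norm of zero means nothing is flowing, and any non-finite entries point to NaNs or infinities. Here is a minimal sketch, assuming `grads` is a pytree whose leaves are arrays (which we'd expect of the `State` that `nnx.value_and_grad` returns); `grad_health` and `report_grad_health` are hypothetical helper names. It uses `jax.debug.print`, which, unlike Python's `print`, fires with real values on every call of a jitted function rather than only while it is being traced.

```python
import jax
import jax.numpy as jnp


def grad_health(grads):
    # Global L2 norm over every leaf, plus a count of NaN/inf entries
    leaves = jax.tree.leaves(grads)
    global_norm = jnp.sqrt(sum(jnp.sum(jnp.square(g)) for g in leaves))
    n_nonfinite = sum(jnp.sum(~jnp.isfinite(g)) for g in leaves)
    return global_norm, n_nonfinite


@jax.jit
def report_grad_health(grads):
    global_norm, n_nonfinite = grad_health(grads)
    # Printed at execution time with actual values, even under jit
    jax.debug.print('grad norm={n}, non-finite entries={b}', n=global_norm, b=n_nonfinite)
    return global_norm, n_nonfinite
```

In a real training loop you'd call something like this right after `nnx.value_and_grad`, inside the jitted step.
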
I printed them out in the `train_step`\nfunction (removing the `jit`\ndecorator, as otherwise the `print`s would only get executed\nin the initial JIT pass through the function to compile it -- not when it had actual data[1](#fn-1)).\n\nThe result was values like this:\n\n```\nState({\n  'output_head': {\n    'kernel': Param( # 38,597,376 (154.4 MB)\n      value=Array([[-2.6879393e-06, -1.2799728e-04,  2.6441864e-09, ...,\n              -1.0780521e-09, -1.9232946e-09,  1.2057198e-04],\n             [ 7.2428256e-06, -9.0873800e-05,  1.9621261e-08, ...,\n               1.9959407e-08,  2.0515712e-08, -1.1401048e-06],\n             [-2.4080187e-05,  1.0717572e-04, -4.7910085e-09, ...,\n              -7.3136892e-09, -5.4990306e-09,  1.4717734e-04],\n             ...,\n             [ 1.9500087e-05,  1.4264552e-05, -3.0880422e-08, ...,\n              -3.0595814e-08, -3.7087858e-08, -1.2066610e-06],\n             [ 1.8085115e-05,  7.6247423e-05, -3.0720415e-08, ...,\n              -3.1052533e-08, -3.1693808e-08, -9.7857817e-05],\n             [ 5.2281484e-06, -1.4398852e-04,  6.2573882e-08, ...,\n               5.5977843e-08,  6.6571232e-08, -1.0639715e-05]], dtype=float32)\n    )\n  },\n  ...\n```\n\nThose looked plausible enough -- pretty small, but not so tiny that I'd expect them to have no effect at all with my learning rate of 0.0014. It was time to dig into the training loop's plumbing.\n\n### Plumbing the depths\n\nThe obvious suspect was the update step -- was that call to `optimizer.update`\nactually changing\nthe parameters at all? 
Flax's NNX API is a bit odd compared to the normal\n[JAX functional way of doing things](https://www.gilesthomas.com/2026/05/on-first-looking-into-jax).\nIn vanilla JAX code you would expect to do something like this to apply gradients:\n\n``` python\nimport jax\n\n# Build a new parameter pytree from the old one, leaf by leaf: p - g * learning_rate\nnew_parameters = jax.tree.map(\n    lambda p, g: p - g * learning_rate,\n    old_parameters,\n    grads,\n)\n```\n\nThat is, you get the new parameters by applying a transformation to the old ones.\n\nNNX, by contrast, is more PyTorch-flavoured. It updates the parameters in-place, using a function whose side effect is to mutate one of its arguments:\n\n```\noptimizer.update(model, grads)\n```\n\n...rather than something more functional like this imaginary API:\n\n```\nmodel = optimizer.apply(model, grads)\n```\n\nI could easily imagine that I'd got something wrong that would break that in-place update, as it has the feel of something that would have to be quite delicately implemented on top of a functional system like JAX.\n\nBut how could I see whether the parameters were changing, when there were 77 million of them and each update (based on gradients like -2.6879393e-06 and a learning rate of 1.4e-3) would only shift them in the ninth decimal place or beyond? Printing the arrays out was a non-starter!\n\n### Hashing it out\n\nAfter a little thought, I realised that the solution was to use hashes. Even tiny changes in the parameters' values would change their hashes drastically. So if the parameters were not being updated, as I suspected, I'd see constant hashes. If they were being updated, even by a minuscule amount, then the hashes would change.\n\n[This GitHub discussion](https://github.com/jax-ml/jax/discussions/8352) pointed me\nin the right direction: if I could get the parameters as pure JAX arrays, I could do this:\n\n```\nprint(hash(np.asarray(some_array).tobytes()))\n```\n\n...where `np` is just `numpy`. 
That would produce a hash that was stable for the life of\nthis run -- the same parameters would always have the same hash, and different ones would\ndiffer, just as we want. It could vary from run to run (Python uses different hash seeds in\neach new interpreter), but that wouldn't matter for\nthis kind of debugging.\n\nI wasn't sure what the structure of my Flax model's parameters was, but printing them out in the training loop told me:\n\n```\nEmbed( # Param: 38,597,376 (154.4 MB)\n  embedding=Param( # 38,597,376 (154.4 MB)\n    value=Array(shape=(50257, 768), dtype=dtype('float32'))\n  ),\n  ...\n)\nLinear( # Param: 38,597,376 (154.4 MB)\n  kernel=Param( # 38,597,376 (154.4 MB)\n    value=Array(shape=(768, 50257), dtype=dtype('float32'))\n  ),\n  ...\n)\n```\n\nSo, guided by that, I added these lines to the training loop:\n\n```\nprint(hash(np.asarray(model.token_embedding.embedding.value).tobytes()))\nprint(hash(np.asarray(model.output_head.kernel.value).tobytes()))\n```\n\nObviously copying the arrays around and converting them like that would slow things down, but for debugging purposes, it looked solid.\n\nI kicked off the training loop, and the problem was clear:\n\n```\n  0%|                         | 43/530640 [00:06<13:39:02, 10.80it/s, loss=10.824, tps=43,576]\n5694185712877458479\n-5759723708627894111\n  0%|                         | 43/530640 [00:06<13:39:02, 10.80it/s, loss=10.824, tps=43,897]\n5694185712877458479\n-5759723708627894111\n```\n\n...and so on. The hashes were not changing, so the model's parameters were not being updated, even by a tiny amount. Gotcha!\n\nThe problem turned out, as I had suspected, to be related to the in-place updates that NNX does. Like I said earlier, I'd based my training loop on the \"Basic Usage\" example on the Flax site -- but I'd messed up one important thing. 
I had this:\n\n``` python\n@jax.jit\ndef train_step(model, optimizer, inputs, targets):\n    loss, grads = nnx.value_and_grad(calculate_loss)(model, inputs, targets)\n    optimizer.update(model, grads)\n    return loss\n```\n\n...and they had this:\n\n``` python\n@nnx.jit  # automatic state propagation\ndef train_step(model, optimizer, x, y):\n  loss_fn = lambda model: ((model(x) - y) ** 2).mean()\n  loss, grads = nnx.value_and_grad(loss_fn)(model)\n  optimizer.update(model, grads)  # in-place updates\n  return loss\n```\n\nYou can see a number of differences -- for example, they're baking the inputs and targets into\nthe lambda they're using for the loss function through a lexical closure, which means\nthat they only pass the model to the version of it wrapped in `value_and_grad`. But\nnone of that matters! The real difference is actually nicely highlighted with a comment,\nbut I'd completely managed to miss it. Right at the start, where I had `@jax.jit`, they had this:\n\n```\n@nnx.jit  # automatic state propagation\n```\n\nIt 100% makes sense that in order to support this kind of non-functional, in-place updating of the model's parameters, you have to have a modified version of the JIT decorator. And I was just using the standard, functional pure-JAX one.\n\nSwitching to `@nnx.jit` fixed the problem:\n\n```\n  0%|                         | 1/530640 [00:06<903:18:25,  6.13s/it, loss=10.824, tps=1,003]\n5024998356359528747\n-4835662927486742764\n  0%|                         | 2/530640 [00:06<397:16:33,  2.70s/it, loss=10.785, tps=1,914]\n6231090084827524676\n8293831317336780907\n  0%|                         | 3/530640 [00:06<228:14:32,  1.55s/it, loss=10.741, tps=2,791]\n7896237091035346857\n-7117477486466304738\n```\n\nThe hashes were changing! And even better, if you scroll to the right you'll see that the loss was slowly dropping. 
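
The same trick generalises to any number of named arrays. Here is a minimal sketch with hypothetical helper names `fingerprint` and `changed_names`. It swaps Python's built-in `hash` for a SHA-256 digest from `hashlib`, which isn't salted per interpreter, so identical parameters give identical fingerprints from one run to the next. You'd feed it host-side copies of the parameters, for example `np.asarray(model.output_head.kernel.value)` as in the debugging lines earlier.

```python
import hashlib

import numpy as np


def fingerprint(named_arrays):
    # Short SHA-256 digest of each array's raw bytes; unlike hash(), this is
    # not randomised per interpreter, so it is comparable across runs
    return {
        name: hashlib.sha256(np.asarray(arr).tobytes()).hexdigest()[:16]
        for name, arr in named_arrays.items()
    }


def changed_names(before, after):
    # Names whose digest differs between two snapshots
    return sorted(name for name in before if before[name] != after.get(name))


# Example: a tiny nudge to a single element is enough to change a fingerprint
w = np.zeros((3, 3), dtype=np.float32)
b = np.ones(3, dtype=np.float32)
before = fingerprint({'w': w, 'b': b})
w[0, 0] += 1e-9
after = fingerprint({'w': w, 'b': b})
print(changed_names(before, after))  # ['w']
```

Hashing around 154 MB per array isn't free, so while debugging you'd probably only take a snapshot every N steps.
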
After 10k or so iterations, I was seeing a loss of 0.000: I had my do-nothing \"LLM\" working.\n\n### Wrapping up\n\nA satisfying debugging journey -- and while I don't think I'll make this specific mistake in the future, I think that the parameter-hashing trick is a really useful one to have in the toolbox. If you're uncertain as to whether your parameters are being updated, just looking at them probably won't help. But looking at their hashes can tell you whether anything is changing.\n\nAnd I think that the pattern that I used to zoom in on it is a useful one, too. I always\ntrack loss, so it's a good starting point (indeed, seeing that it wasn't falling was what\ntold me that something was going wrong). But checking that it has a sane -- or ideally, as\nin this case, a meaningful -- value is a nice sanity check that we have a working loss function\nand a model that isn't doing something completely pathological. Moving on from there to\nchecking that some kind of gradients are flowing through is a solid next move (and might\nbecome increasingly interesting with deeper models where [they can vanish or explode](/2026/02/llm-from-scratch-32b-interventions-gradient-clipping)).\nThen finally we can check the parameters -- in particular, are they changing?[2](#fn-2)\n\nLet's see how many new tricks I pick up as I work through this LLM project.\n\n1. I always forget that `jax.debug.print` exists -- I could have used that instead, and kept the JIT.[↩](#fnref-1)\n2. Something's slightly broken in my brain and I keep reading that as \"is our parameters changing\" [in George W. Bush's voice](https://www.goodreads.com/quotes/273762-rarely-is-the-question-asked-is-our-children-learning). Maybe I can stop that from happening by inflicting it on my readers instead. 
You're welcome.[↩](#fnref-2)", "url": "https://wpnews.pro/news/flax-debugging-making-a-hash-of-things", "canonical_source": "https://www.gilesthomas.com/2026/06/hashing-jax-parameters", "published_at": "2026-06-17 02:11:12+00:00", "updated_at": "2026-06-17 02:22:47.959045+00:00", "lang": "en", "topics": ["machine-learning", "large-language-models", "developer-tools"], "entities": ["JAX", "Flax", "NNX", "GPT-2"], "alternates": {"html": "https://wpnews.pro/news/flax-debugging-making-a-hash-of-things", "markdown": "https://wpnews.pro/news/flax-debugging-making-a-hash-of-things.md", "text": "https://wpnews.pro/news/flax-debugging-making-a-hash-of-things.txt", "jsonld": "https://wpnews.pro/news/flax-debugging-making-a-hash-of-things.jsonld"}}