{"slug": "first-looking-into-jax", "title": "First Looking into Jax", "summary": "A developer who has spent years working extensively with PyTorch has begun exploring JAX, a competing machine learning framework, and reports finding it cleaner and more mathematically pure. The author contrasts JAX's functional, just-in-time compiled approach with PyTorch's procedural, piecewise-optimized design, arguing the difference reflects a fundamental divide between mathematical and engineering philosophies in deep learning tooling.", "body_md": "## On first looking into JAX\n\nMuch have I travell'd in the realms of gold,\n\nAnd many goodly states and kingdoms seen;\n\nRound many western islands have I been\n\nWhich bards in fealty to Apollo hold.\n\nOft of one wide expanse had I been told\n\nThat deep-brow'd Homer ruled as his demesne;\n\nYet did I never breathe its pure serene\n\nTill I heard Chapman speak out loud and bold:\n\nThen felt I like some watcher of the skies\n\nWhen a new planet swims into his ken;\n\nOr like stout Cortez when with eagle eyes\n\nHe star'd at the Pacific -- and all his men\n\nLook'd at each other with a wild surmise --\n\nSilent, upon a peak in Darien.\n\nJohn Keats, *On First Looking into Chapman's Homer*\n\nI've been working with [PyTorch](https://pytorch.org/) quite a lot for the last couple of years, and feel\nlike I've come to a reasonably solid understanding of how it all fits together.\n[Working through](/llm-from-scratch) [Sebastian Raschka](https://sebastianraschka.com/)'s book\n\"[Build a Large Language Model (from Scratch)](https://www.manning.com/books/build-a-large-language-model-from-scratch)\",\ntraining my own LLMs [locally](/2025/12/llm-from-scratch-28-training-a-base-model-from-scratch) and [in the cloud](/2026/04/llm-from-scratch-32j-interventions-trying-to-train-a-better-model-in-the-cloud),\n[rebuilding Andrej Karpathy's 2015-vintage RNNs](/2025/10/retro-language-models-rebuilding-karpathys-rnn-in-pytorch) --\nover time, it all 
adds up!\n\nBut, of course, there are other frameworks, and one I kept hearing about was\n[JAX](https://docs.jax.dev/en/latest/index.html). While it's less\ndominant than PyTorch, it has a reputation for a certain cleanliness, a certain purity.\nAnd having spent time over the last couple of weeks working through the tutorials, and translating small PyTorch examples\ninto it, I've been really impressed.\n\nIn this post I want to give an overview -- to report back to beginners like me, still living in PyTorch-land, on my new discovery. Less like Herschel discovering Uranus, and more like a 16th-century European coming back after having discovered something that the people who lived there were perfectly well aware of. What is this JAX thing, and how does it differ from PyTorch?\n\n### Some theses, significantly overstated\n\nI think that the main differences between PyTorch and JAX are something like this, but a little less strident:\n\n- PyTorch is engineering; JAX is maths.\n- PyTorch has historically\nbeen optimised piecewise; JAX is JITted.[1](#fn-1)\n- PyTorch is procedural; JAX (tries to be) functional.\n- PyTorch is maximalist; JAX is minimalist.\n\nHaving overstated my claims, let me dig in and perhaps walk them back a bit. Once I've gone through them, I'll do a walkthrough of porting a simple PyTorch training loop to JAX, which should illustrate the points well.\n\nFinally, I'll wrap up with the counterargument. JAX is wonderful and shiny, and 30+ years of industry experience and cynicism make me fear that it might be doomed :-(\n\nBut let's start with the positive! [Happy face on.]\n\n### 1. Maths versus engineering\n\nA simple example that nicely contrasts the different philosophies of the two frameworks is what the core of a training loop looks like. 
Here's how you might write one in PyTorch:\n\n``` python\noptimizer.zero_grad()\n\nresult = model(inputs)\nloss = loss_function(result, targets)\n\nloss.backward()\n\noptimizer.step()\n```\n\nThis is kind of mechanistic. You're telling the computer what to do, step by step:\n\n- Zero out the gradients that you currently have attached to the parameters.\n- Do a forward pass to get the model's outputs.\n- Work out the loss based on those outputs.\n- Do the backward pass.\n- Update the parameters based on the gradients that the backward pass attached to them.\n\nNow let's look at a parallel JAX implementation:\n\n``` python\ndef calculate_loss(parameters, inputs, targets):\n    result = forward(parameters, inputs)\n    return loss_function(result, targets)\n\n...\n\ndef train():\n    ...\n\n    grads = jax.grad(calculate_loss)(layers, inputs, targets)\n\n    layers = step(layers, grads, learning_rate)\n```\n\nIt's clearly very different. No explicit backward pass, no gradient-zeroing, and the forward pass and loss calculation are baked into a separate function. But why is it shaped that way?\n\nLet's think about what we're actually doing in our training loop. The gradients are the partial derivatives of the loss function, $L$, with respect to the weights, $W$:\n\n$$\\frac{\\partial L}{\\partial W}$$\n\nNow, I'm being a bit sloppy with that notation, because $L$ is a function, and it -- in the mathematical formulation -- takes the weights as a parameter. So it would be better written like this:\n\n$$\\frac{\\partial L(W)}{\\partial W}$$\n\nBut that's still not quite right. In a real training loop, we're doing this in the\ncontext of a particular input batch, $X$,\nand its associated targets, $Y$.[2](#fn-2) We might write that mathematically as this:\n\n$$\\frac{\\partial L(W : X, Y)}{\\partial W}$$\n\n...where you can read the colon as \"given\". 
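\n\nTo make that notation a bit more concrete, here's a minimal sketch of $\\frac{\\partial L(W : X, Y)}{\\partial W}$ in JAX. It isn't from the toy project later in this post. The bias-free linear model, the mean-squared-error loss and the names `loss`, `W`, `X` and `Y` are all assumptions for illustration. That loss has a known closed-form gradient, $\\frac{2}{n} X^T (XW - Y)$ with $n$ the number of elements in $Y$. That lets us check that `jax.grad`, which differentiates with respect to the first argument while treating the others as fixed, gives the same answer:\n\n``` python\nimport jax\nimport jax.numpy as jnp\n\n# Hypothetical loss L(W : X, Y): mean squared error of a bias-free linear model\ndef loss(W, X, Y):\n    return jnp.mean((X @ W - Y) ** 2)\n\nx_key, w_key = jax.random.split(jax.random.key(0))\nX = jax.random.normal(x_key, (8, 3))  # a batch of 8 inputs with 3 features each\nW = jax.random.normal(w_key, (3, 1))  # the weights we differentiate with respect to\nY = jnp.ones((8, 1))                  # the targets for this batch\n\n# dL/dW evaluated at W, given this particular X and Y\ngrad_W = jax.grad(loss)(W, X, Y)\n\n# The closed-form gradient of this loss, for comparison\nanalytic = 2.0 / Y.size * X.T @ (X @ W - Y)\nprint(jnp.allclose(grad_W, analytic, atol=1e-4))  # prints True\n```\n\nThe gradient comes back with the same shape as `W`, which is exactly what you need to update the weights.\n\n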
Now let's look again at the JAX code to work out the gradients:\n\n```\ngrads = jax.grad(calculate_loss)(layers, inputs, targets)\n```\n\nThat's an almost-perfect mirror of the maths!\n\nThe `jax.grad` function takes a function `f` and returns another function, `g`,\nwhich takes the same arguments. When you call `g`, instead of returning the result\nof `f`, it will return the derivative of `f` with respect to its first argument,\ngiven the values of the others.[3](#fn-3)\n\nHow is it doing that magic? Let's look at a simple concrete example:\n\n``` python\nimport jax\n\ndef f(x, y):\n    print(f\"In the function {x=}, {y=}\")\n    return x + y\n```\n\nIf you do the initial call to `grad`:\n\n```\ng = jax.grad(f)\n```\n\n...then it just wraps `f` in a helper function. It's when you call `g` that the magic\nhappens.\n\n```\ng(2.0, 1.0)\n```\n\n...will print out this:\n\n```\nIn the function x=GradTracer(primal=2.0, typeof(tangent)=f32[]), y=1.0\n```\n\nThe first parameter -- the one with respect to which we're asking for the derivative --\nis replaced by a `GradTracer` object. Because it's wrapping a float, it can\nbe used like one, so the function executes as expected. But it also keeps track of what happens to this\nvariable as the code executes, and essentially builds up what in PyTorch would be\nrepresented by the computation graph.\n\nSo: in PyTorch, the tensors that you need gradients for have to be special objects that can keep a reference to those gradients --\nhence the `requires_grad` parameter that pops up frequently in PyTorch code. In JAX, it's\nall handled by the arguments being automatically wrapped in these special tracers.\n\nOnce it has the result of the function as a whole, including the chain of operations that was traced, it can automatically do a backward pass, and we're done. 
That's really nifty!\n\nNow, the example above was a toy one: the argument we were differentiating with respect to was a single scalar.\nIn a real training loop, you're differentiating with respect to a set of weights, and\nthose will be something more complex. But `grad` handles that gracefully. Let's see what happens if we pass in an array\nas the first parameter:\n\n``` python\n>>> import jax\n>>> import jax.numpy as jnp\n>>> def f(x, y):\n...     print(f\"In the function {x=}, {y=}\")\n...     return (x + y).sum()\n...\n>>> g = jax.grad(f)\n>>> g(jnp.array([1., 2., 3.]), jnp.array([4., 5., 6.]))\nIn the function x=GradTracer(primal=[1. 2. 3.], typeof(tangent)=f32[3]), y=Array([4., 5., 6.], dtype=float32)\nArray([1., 1., 1.], dtype=float32)\n```\n\n(The `.sum()` is there because `grad` only works on functions that return a scalar -- which, conveniently, is exactly what a loss function does.)\n\nSo, we've got partial derivatives with respect to the elements of the array that was the first parameter -- just what we'd need for a single-layer neural network without bias.\n\nBut what about something more complicated? For something like (say) an LLM, we have quite a lot of structure to our weights: our input embeddings, output head, all of the layers with their attention and feed-forward weights, and so on.\n\n`grad` handles that by understanding basic Python structures --\nthings that can be mapped to what JAX calls PyTrees. PyTrees are nested tree structures of dictionaries,\nlists, tuples and so on, where the leaves are numbers or JAX arrays.[4](#fn-4)\n\nIf you ask for gradients with respect to a variable that can be represented by a PyTree, you get them back in a form that mirrors that PyTree:\n\n``` python\n>>> def f(x, y):\n...     print(f\"In the function {x=}, {y=}\")\n...     return (x[\"a\"][\"b\"] + y).sum()\n...\n>>> g = jax.grad(f)\n>>> g({\"a\": {\"b\": jnp.array([1., 2., 3.])}}, jnp.array([4., 5., 6.]))\nIn the function x={'a': {'b': GradTracer(primal=[1. 2. 
3.], typeof(tangent)=f32[3])}}, y=Array([4., 5., 6.], dtype=float32)\n{'a': {'b': Array([1., 1., 1.], dtype=float32)}}\n```\n\nIf you pair that with JAX's tree-aware `map` function (`jax.tree.map`), you can combine those gradients with the\noriginal parameters to update them as you train. I'll show you how that works later on, when we go through an\nexample of porting some PyTorch code to JAX.\n\nSo, all of that cool stuff was made possible by the tracer objects, which are passed in instead of the real parameters, and keep track of the computation graph (just like the graph that PyTorch attaches directly to the variables).\n\nBut tracers are more generally useful than that; they really come into their own with the next JAX difference: the JIT.\n\n### 2. JIT vs piecewise optimisation\n\nImagine that you've built some kind of nifty model in PyTorch. As part of it, you do a calculation something like this:\n\n$$\\text{MaxSim}(Q, D) = \\sum_{i} \\max_{j} \\, q_i \\cdot d_j$$\n\n...where $q_i$ and $d_j$ are the embedding vectors in two sets, $Q$ and $D$. For each vector in $Q$, you find its best dot-product match in $D$, then add up those best scores.\n\nYou decide that this is generally useful, so you\n[code it up as a CUDA kernel](https://huggingface.co/kernels/erikkaum/maxsim) and make\nit available to the community, like Erik Kaunismäki has with his \"MaxSim\" kernel. Maybe later on, it will get\nadded to the PyTorch library as a standard component.\n\nThere are a lot of optimisations like that built into PyTorch; people found that there\nwere higher-level abstractions on top of basic tensor operations that were generally useful,\nso they coded up lower-level optimised versions. For example, in the LLM I've been\nworking with, there is [an implementation of LayerNorm](/2025/07/llm-from-scratch-16-layer-normalisation).\nBut PyTorch has [its own one built in](https://docs.pytorch.org/docs/2.12/generated/torch.nn.LayerNorm.html).\nAnd there's a [CUDA implementation](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/cuda/layer_norm_kernel.cu) that\nit will use automatically if it has the appropriate hardware available.\n\nThere is a problem, though. 
Imagine that someone else is working on a different kind of model in the future. And for reasons completely unrelated to the MaxSim calculations that Kaunismäki nicely optimised, they happen to need to do the same calculations.\n\nNow, there are two things that can happen from there:\n\n- They don't know that the MaxSim kernel exists, so their code remains unoptimised.\n- They do know that it exists, so they repurpose it for whatever their use case is.\n\nThe first is not ideal, but the second isn't great either. If what they're using it for isn't really a MaxSim operation, just something that happens to look the same mathematically, their code now calls a kernel whose name doesn't describe what it's being used for.\n\nIn the general case: all optimisations that get into PyTorch have to be carefully named so that they reflect the exact level of abstraction that they're targeting. And when people are writing PyTorch models, they need to actually know which optimised abstractions are available, and where to apply them.\n\nNow let's look at JAX.\n\nIt has an innocuous-looking decorator, `jit`, and you can use it by adding a single\nline before your function:\n\n``` python\nimport jax\nimport jax.numpy as jnp\n\n@jax.jit\ndef selu(x, alpha=1.67, lambda_=1.05):\n    return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)\n```\n\nBehind that single line is a huge amount of useful infrastructure. Just like\n`grad`, it's a function that takes one function and returns another, without necessarily running\nthe underlying code.[5](#fn-5) But when you call the wrapped function for the first time, some impressive stuff\nhappens:\n\n```\nselu(1.234)\n```\n\nThis will essentially execute the `selu` code twice:\n\nThe first time through, it will create another of those tracer objects; this time, though, it won't wrap the number\n`1.234` -- it will just know that it represents a float. 
It will call the Python code with that tracer, and all of the operations in the function will be run, but the result that comes out at the end will essentially just be a representation of what calculations were done in an abstract sense -- like the computation graph that was used for working out gradients, but without specific numbers in it.\n\nJAX has a nice way to display these representations, which it calls JAXPRs, and the JAXPR for that function when called with a float parameter will look something like this:\n\n```\n{ lambda ; a:f32[]. let\n    b:bool[] = gt a 0.0:f32[]\n    c:f32[] = exp a\n    d:f32[] = mul 1.67:f32[] c\n    e:f32[] = sub d 1.67:f32[]\n    f:f32[] = jit[\n      name=_where\n      jaxpr={ lambda ; b:bool[] a:f32[] e:f32[]. let\n          f:f32[] = select_n b e a\n        in (f,) }\n    ] b a e\n    g:f32[] = mul 1.05:f32[] f\n  in (g,) }\n```\n\nThat JAXPR can be compiled into the appropriate code for the platform where you're running it -- x86 machine code, compiled CUDA, or the equivalent for AMD GPUs or Google's Tensor Processing Units (TPUs) -- and the result will be cached. The key for the cache will be meta-information about the parameter -- in this case, something like \"a 32-bit floating-point scalar\".\n\nNext, the compiled code -- not the original Python -- is run with the actual value of the parameter, the\n`1.234` that we provided.\n\nNow, of course, the advantage of doing this is that when you call it with a different\nfloating-point number -- say, `5.678` -- then you don't need to do the compilation again.\nYou can just rely on the cached version. And because the compiled code is\ncached based on the metadata, if you call `selu` with a vector, it will compile\na new version for that, and likewise for a matrix version.[6](#fn-6)\n\nThis is all really nifty, and you can see how it would help right away. 
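\n\nOne way to watch that cache at work is a small experiment. It's my own sketch, not something from the JAX docs. It deliberately breaks the functional rules discussed in the next section by incrementing a hypothetical global counter, `trace_count`, inside the `selu` function from above. That Python side effect only runs while JAX is tracing, not when cached compiled code is reused, so the counter tells us how many versions were compiled:\n\n``` python\nimport jax\nimport jax.numpy as jnp\n\ntrace_count = 0  # hypothetical counter, only for this experiment\n\n@jax.jit\ndef selu(x, alpha=1.67, lambda_=1.05):\n    global trace_count\n    trace_count += 1  # Python side effect: runs during tracing only\n    return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)\n\nselu(1.234)                   # float scalar: traced and compiled\nselu(5.678)                   # same metadata: cached version reused\nselu(jnp.array([1.0, -1.0]))  # a vector: traced and compiled again\nselu(jnp.ones((2, 2)))        # a matrix: traced and compiled again\nprint(trace_count)            # prints 3\n```\n\nThe counter ends up at 3 rather than 4 because the second scalar call hit the cache.\n\n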
But for me,\nat least, an excellent extra benefit is how it can save people like Erik Kaunismäki the bother\nof writing custom kernels. The compilation that happens, taking the representation\nthat it got from the tracing process and turning it into backend code, goes through an\noptimising compiler, [XLA](https://openxla.org/). And that compiler can recognise patterns of\n\"standard\" operations and fuse them together.\n\nThis won't be at the level of \"standard operations\" like MaxSim, of course -- more, \"this looks like a convolution, let's use the standard kernel\". But it does mean that instead of someone having to take code written in Python and hand-port it over to CUDA to get a GPU speedup, the same expertise can be put into improving the optimisation part of XLA to get a speedup for all code.\n\nThat's pretty amazing. However...\n\n### 3. Procedural vs functional code\n\nIf you want something like the JIT to work properly, you need to limit the kind of code that it works with. In particular, it needs to be functional. A function must always return the same value when given the same inputs, and can't depend on state outside itself -- so this is fine:\n\n``` python\nimport jax\n\n@jax.jit\ndef add(x, y):\n    return x + y\n\nprint(add(1, 2))\nprint(add(1, 3))\n```\n\n...but this will cause problems:\n\n``` python\n@jax.jit\ndef addY(x):\n    return x + y\n\ny = 2\nprint(addY(1))\ny = 3\nprint(addY(1))\n```\n\n...because `y` could be changed. Specifically, because the global `y` had the value\n`2` during the initial traced run of the function, that value will essentially get hard-coded\ninto the cached JITted version, so both prints in the second example will output `3`.\n\nSomething slightly surprising comes out of this -- something that makes JAX code look very different to PyTorch. 
How we handle randomness needs to completely change.\n\nConsider this code:\n\n``` python\nimport random\n\ndef f(x):\n    return x + random.randint(1, 10)\n\nrandom.seed(42)\n\nprint(f(1))\nprint(f(1))\n```\n\nAs a whole, it's deterministic. But it breaks the functional requirement that\nthe function can only depend on its inputs. Both calls to `f` take the same input,\nbut they can return different results.\n\nEven worse, if we were to do something that consumed\nrandomness between those two calls to `f`, for example:\n\n```\nprint(f(1))\nrandom.randint(1, 10)\nprint(f(1))\n```\n\n...then the second call could return something different from what it returned before, even though nothing about its input has changed. The state of the random number generator is\nglobal state kept outside the function, just like `y` in the `addY` example above.\n\nA naive solution to this might be to make the state of the RNG explicit as a variable -- you can imagine a library that worked something like this:\n\n``` python\nimport updated_random\n\ndef f(x, random_state):\n    return x + random_state.randint(1, 10)\n\nrandom_state = updated_random.new_state(42)\nprint(f(1, random_state))\nprint(f(1, random_state))\n```\n\nThat *looks* more functional, but when you think\nabout it, we haven't actually fixed the problem. 
We're passing in the same `random_state` variable in both cases, along with the same number, but we're getting different results.\nIt's not global, but it's still mutable behind the scenes.\n\nWhat you'd actually need to do to make it purely functional would be something like this:\n\n``` python\nimport updated_random\n\ndef f(x, random_state):\n    new_state, randint = updated_random.randint(random_state, 1, 10)\n    return new_state, x + randint\n\ninitial_random_state = updated_random.new_state(42)\nfirst_call_random_state, result = f(1, initial_random_state)\nprint(result)\nsecond_call_random_state, result = f(1, first_call_random_state)\nprint(result)\n```\n\nThe `updated_random.randint` function generates a new random integer and returns\nboth that and the new state of the RNG; `f` then passes that new state back along with its\nresult. We've made the random state variables immutable, and so it's functional. But the\nAPI is getting pretty ugly pretty quickly.\n\nSo JAX does something that is\nequivalent, but a bit cleaner. There's a concept of a *key*, which needs to be passed\ninto any function that consumes randomness:\n\n```\nkey = jax.random.key(42)\n```\n\nThat's kind of like the `random_state` that we have in the first version of the code above.\nBut it's immutable; when you use it, like this:\n\n```\njax.random.randint(key, (), 1, 11)\n```\n\n...it will not be changed, so no matter how many times you call it with the same\nkey, that function will return the same value. (Note that `jax.random.randint` takes an inclusive lower bound and an exclusive upper bound, like Python's `range`,\nbut unlike the stdlib's `random.randint`. It also needs to know the shape of the\nresult -- `()` for a scalar, `(1, 2)` for a 1x2 array, and so on.)\n\nIf you want it to \"move on\" to a new state, you use the `split` function, which\ntakes an existing key and returns two (or more) new ones. 
So you can do something like this:\n\n``` python\nimport jax.random\n\ndef f(x, key):\n    return x + jax.random.randint(key, (), 1, 11)\n\ninitial_key = jax.random.key(42)\nfirst_call_key, new_key = jax.random.split(initial_key)\nprint(f(1, first_call_key))\nsecond_call_key, new_new_key = jax.random.split(new_key)\nprint(f(1, second_call_key))\n```\n\nNow, that `new_key` and `new_new_key` stuff is a bit ugly. But while it's not OK\nto mutate the contents of variables in functional code, it's absolutely fine to assign a new value to\nan existing variable, so what I've found myself doing is writing stuff like this:\n\n``` python\nimport jax.random\n\ndef f(x, key):\n    return x + jax.random.randint(key, (), 1, 11)\n\nkey = jax.random.key(42)\nfirst_call_key, key = jax.random.split(key)\nprint(f(1, first_call_key))\nsecond_call_key, key = jax.random.split(key)\nprint(f(1, second_call_key))\n```\n\nHowever, there are more powerful ways to use `split`; I'm not yet confident enough with\nit to go into those, though, so I'll hold back for now. I suspect (assuming I keep using JAX) I'll be\nposting about them in the future.\n\nOK: so the JIT means that we have to write functional code, which makes things a bit fiddly -- no more global state. And that has a surprisingly big knock-on effect on randomness.\n\nBut there's another thing that comes out of the JIT and the way it does tracing. It's not a functional thing (though some of the docs seem almost to treat it as one), but it's caused by the same kind of constraints. It's not part of my four theses above, but I think it's important enough to call out in its own subsection.\n\n### 3.5. Control flow and values\n\nImagine this function:\n\n``` python\nimport jax\n\n@jax.jit\ndef f(x):\n    if x > 2:\n        return x ** 2\n    return x\n\nprint(f(10.0))\n```\n\nIt's purely functional, so no problem there.\nBut let's think about what the JIT is trying to do. 
It wants to convert the function\ninto a simple sequence of operations, so it will create a tracer for\na floating-point scalar, then call `f` with it.\n\nWhen it hits that `if` statement, there will be\na problem. The tracer is meant to represent any arbitrary float, so should it take\nthe `if` branch or not? There's no good answer. It doesn't know which branch to follow\n-- whether the sequence should be \"square it and return the result\" or just\n\"return it directly\" --\nso it will fail with a somewhat obscure error message:\n\n```\njax.errors.TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].\n```\n\nSo this gives a hard constraint on functions that you want to JIT: by default, they can't base control flow on the values you pass in. There are workarounds -- but they come with tradeoffs. Let's take a slightly sideways route to explain them.\n\nFirstly, although you cannot do control flow based on the *value* of a parameter -- which\nthe tracer doesn't know -- you\ncan base it on other information that actually is stored in the tracer. Let's say that we called\n`f` like this:\n\n```\nf(jax.numpy.array([[1., 2.], [3., 4.]]))\n```\n\nThe tracer that would be passed in when trying to trace the function would be\nsomething representing a 2x2 array. The *shape* of the parameter is part of the tracer, even\nthough the values aren't. So you could do something like this:\n\n``` python\n@jax.jit\ndef f(x):\n    if len(x.shape) > 1:\n        return x ** 2\n    return x\n```\n\n...and it would work. It's worth thinking explicitly about why this is.\n\nWhen you call a JITted function, it will create a tracer that contains information about the type of thing you passed in as a parameter -- scalar versus array, and if it's an array, the array's shape. 
It then runs the function with the tracer, gets the sequence of operations, compiles them and then stores the result in a cache keyed on the metadata -- type and, if appropriate, shape -- that it used to create the tracer.\n\nSo when we call that function with a 2x2 array, we get a 2x2 array version, then if we call it later with a one-dimensional array of length 2, we'll get a new version for that.\n\nOne workaround for basing control flow on values is essentially to tell the `jit` function\nthat it should treat the value of a particular argument as part of the cache key, just like the metadata: it should\ncompile a new version for each value it sees, rather than just for each new set of metadata. `jit`\ntakes a parameter `static_argnums`, and a matching `static_argnames`, which tell it\nwhich parameters to do that with. So, this will work:\n\n``` python\nfrom functools import partial\n\nimport jax\n\n@partial(jax.jit, static_argnums=(0,))\ndef f(x):\n    if x > 2:\n        return x ** 2\n    return x\n\nprint(f(10.0))\n```\n\n(Remember that the expression after the `@` in a decorator needs to evaluate to a function that\ntakes the decorated function and returns a new one, so we use `partial` to \"inject\" the extra argument into `jax.jit`.)\n\nHowever, the downside is pretty clear: every time we call `f` with a new value, it's\ngoing to have to JIT a new version of the function and cache it -- that's going to be slow\nand take up memory.\n\nSo, as an alternative, we can use [the jax.lax package](https://docs.jax.dev/en/latest/jax.lax.html#control-flow-operators).\nThis provides more functional-looking alternatives for control flow, which\n*are* compatible with the way the JIT works. For example, there's a\n`cond` function, which\nwe can use to replace `if`s:\n\n``` python\n@jax.jit\ndef f(x):\n    return jax.lax.cond(x > 2, lambda: x ** 2, lambda: x)\n\nprint(f(10.0))\n```\n\nThat feels a little bit like a workaround, but it does solve the problem. How? 
Well, it's worth checking the JAXPR for it:\n\n```\n>>> jax.make_jaxpr(f)(10.0)\n{ lambda ; a:f32[]. let\n    b:f32[] = jit[\n      name=f\n      jaxpr={ lambda ; a:f32[]. let\n          c:bool[] = gt a 2.0:f32[]\n          d:i32[] = convert_element_type[new_dtype=int32 weak_type=False] c\n          b:f32[] = cond[\n            branches=(\n              { lambda ; e:f32[]. let  in (e,) }\n              { lambda ; f:f32[]. let g:f32[] = integer_pow[y=2] f in (g,) }\n            )\n          ] d a\n        in (b,) }\n    ] a\n  in (b,) }\n```\n\nWhat's happened here is that `jax.lax.cond` is itself a primitive in JAX's intermediate language. Rather than picking a branch while tracing, it traces *both* branches into little sub-JAXPRs (you can see them in the `branches` list above). It then records a single `cond` operation that chooses between them when the compiled code actually runs. That couldn't happen with the `if`. When JAX was tracing, all it saw was what was\nhappening to the tracer: a boolean comparison, and then Python asking for a concrete `True` or `False` so that it could pick a\nbranch, which the tracer couldn't supply. The `if` itself happened in Python, outside\nJAX, so it was \"invisible\" to the trace.\n\nThat feels a little inelegant to me right now, and I'll come back to it later.\n\nLet's move on to the final difference between the two libraries that I want to cover: JAX's relative minimalism compared to PyTorch's more maximalist approach.\n\n### 4. Minimalism versus maximalism\n\nI think the smaller size of JAX -- at least in terms of its API, if not in terms of the JIT and XLA magic under the hood -- compared to the sprawl of PyTorch is not entirely unrelated to the JIT being at its core.\n\nPyTorch, after some initial design, has almost been forced to grow organically; JAX feels more carefully designed,\nso it doesn't have the same *need* to grow (though of course it can).\n\nThe reason for PyTorch's growth is, at least in part, that\nit needs to absorb optimisations. If something is slow, someone needs to write a CUDA\nkernel for it. If there's a CUDA kernel, it needs an API. 
And if it is generally useful, that API becomes part of\nPyTorch. Multi-head attention? [There's a class for that](https://docs.pytorch.org/docs/2.12/generated/torch.nn.MultiheadAttention.html).\nSELU? [Yup](https://docs.pytorch.org/docs/2.12/generated/torch.nn.SELU.html).\nVery specific softmax approximations based on a paper published in 2016?\n[PyTorch has you covered](https://docs.pytorch.org/docs/2.12/generated/torch.nn.AdaptiveLogSoftmaxWithLoss.html).\n\nBy contrast, JAX doesn't even have linear layers or optimisers in the framework\nitself; if you want to use them, you can write them yourself (contraindicated), or\nyou can use [libraries built on top of JAX](https://docs.jax.dev/en/latest/#ecosystem), like\n[Flax](https://flax.readthedocs.io/en/stable/) for common neural network components\nand [Optax](https://optax.readthedocs.io/en/latest/) for optimisers.\n\nThis feels like a nice division of responsibilities, and it also seems like something\nthat would have been very hard without the JIT. So while the JAX core may well grow in the\nfuture, the design it has now puts it in a good position to grow in a more planned,\nwell-designed manner -- rather than *having* to grow to absorb more and more abstractions\njust to keep it fast. Those abstractions can more easily sit in libraries written on\ntop of JAX.\n\n### Porting a toy PyTorch model to JAX\n\nThat's the 10,000-foot overview; four (or maybe four and a half) main differences between PyTorch and JAX. It's more maths-y, JITted, functional and minimalist. What does that actually mean when you get down to coding with it? Let's get into the weeds with an example.\n\nLet's use a really simple one: training a neural network with two inputs and one hidden\nlayer to calculate the XOR function. 
The code is in [this GitHub repo](https://github.com/gpjt/toy-pytorch-to-jax),\nbut I'll put the relevant bits here in this post.\n\nFirstly, an idiomatic PyTorch implementation:\n\n``` python\nimport time\n\nimport torch\n\ndata = [\n    ([0., 0.], [0]),\n    ([0., 1.], [1]),\n    ([1., 0.], [1]),\n    ([1., 1.], [0]),\n]\n\nclass XORModel(torch.nn.Module):\n\n    def __init__(self):\n        super().__init__()\n        self.layer1 = torch.nn.Linear(2, 2, bias=True)\n        self.layer1_activation = torch.nn.Sigmoid()\n        self.layer2 = torch.nn.Linear(2, 1, bias=True)\n        self.layer2_activation = torch.nn.Sigmoid()\n\n    def forward(self, x):\n        hidden = self.layer1_activation(self.layer1(x))\n        output = self.layer2_activation(self.layer2(hidden))\n        return output\n\ndef calculate_loss(model, inputs, target):\n    result = model(inputs)\n    return ((result - target) ** 2).mean()\n\ndef main():\n    torch.manual_seed(42)\n\n    model = XORModel()\n    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)\n\n    start = time.time()\n    for epoch in range(10000):\n        losses = []\n\n        for x, y in data:\n            optimizer.zero_grad()\n\n            loss = calculate_loss(model, torch.tensor(x), torch.tensor(y))\n            loss.backward()\n            losses.append(loss.item())\n\n            optimizer.step()\n\n        if epoch % 1000 == 0:\n            avg_loss = sum(losses) / len(losses)\n            print(f\"Loss at epoch {epoch}: {avg_loss:.6f}\")\n    end = time.time()\n\n    print(f\"Trained in {end - start:.3f}s\")\n\n    print(f\"Loss at end: {avg_loss:.6f}\")\n\n    model.eval()\n    with torch.no_grad():\n        for x, y in data:\n            result = model(torch.tensor(x))\n            print(f\"{x=}: {result=}, {y=}\")\n\nif __name__ == \"__main__\":\n    main()\n```\n\nIf we run that, it trains a solid-looking model in about four seconds on my machine:\n\n``` bash\ngiles@perry:~/Dev/toy-pytorch-to-jax 
(main)$ uv run pytorch_xor.py\nLoss at epoch 0: 0.279327\nLoss at epoch 1000: 0.254715\nLoss at epoch 2000: 0.254279\nLoss at epoch 3000: 0.253985\nLoss at epoch 4000: 0.253649\nLoss at epoch 5000: 0.251566\nLoss at epoch 6000: 0.189219\nLoss at epoch 7000: 0.030093\nLoss at epoch 8000: 0.006666\nLoss at epoch 9000: 0.003516\nTrained in 4.154s\nLoss at end: 0.003516\nx=[0.0, 0.0]: result=tensor([0.0483]), y=[0]\nx=[0.0, 1.0]: result=tensor([0.9567]), y=[1]\nx=[1.0, 0.0]: result=tensor([0.9425]), y=[1]\nx=[1.0, 1.0]: result=tensor([0.0434]), y=[0]\n```\n\nNow, if we're porting to JAX we need to do something about the fact that JAX doesn't have optimisers and the neural network stuff built in. If this were a real codebase, we'd almost certainly do that by using the libraries built on top of JAX, like Flax and Optax. But for this toy example, I think it's more illustrative to strip down the PyTorch version so that it uses fewer parts of the API -- essentially so that it only uses the stuff that JAX has -- and then to port the result.\n\nThe optimiser first. [The code is here](https://github.com/gpjt/toy-pytorch-to-jax/blob/main/pytorch_xor_no_optimizer.py)\nbut the diffs are pretty simple. 
Instead of creating an optimiser, we just\nspecify our learning rate:\n\n```\n<     optimizer = torch.optim.SGD(model.parameters(), lr=0.1)\n---\n>     learning_rate = 0.1\n```\n\nInstead of zeroing out the gradients using the optimiser, we can just ask the model to do it:\n\n```\n<             optimizer.zero_grad()\n---\n>             model.zero_grad()\n```\n\nAnd instead of stepping the optimiser, we call a new `step` function, passing in\nthe model and the learning rate:\n\n```\n<             optimizer.step()\n---\n>             step(model, learning_rate)\n```\n\nThe `step` function is simple enough; we just switch into `no_grad` mode so that PyTorch\ndoesn't try to track the computation graph (working out gradients for applying gradients\nand triggering some kind of crazy gradient-ception), then we just iterate over the model's\nparameters and follow the normal SGD process, subtracting the\ngradients times the learning rate:\n\n``` python\ndef step(model, learning_rate):\n    with torch.no_grad():\n        for p in model.parameters():\n            if p.grad is not None:\n                p -= p.grad * learning_rate\n```\n\nRunning that on my machine actually works out slightly faster than the original![7](#fn-7)\n\n``` bash\ngiles@perry:~/Dev/toy-pytorch-to-jax (main)$ uv run pytorch_xor_no_optimizer.py\nLoss at epoch 0: 0.279327\nLoss at epoch 1000: 0.254715\nLoss at epoch 2000: 0.254279\nLoss at epoch 3000: 0.253985\nLoss at epoch 4000: 0.253649\nLoss at epoch 5000: 0.251566\nLoss at epoch 6000: 0.189219\nLoss at epoch 7000: 0.030091\nLoss at epoch 8000: 0.006665\nLoss at epoch 9000: 0.003516\nTrained in 3.806s\nLoss at end: 0.003516\nx=[0.0, 0.0]: result=tensor([0.0483]), y=[0]\nx=[0.0, 1.0]: result=tensor([0.9567]), y=[1]\nx=[1.0, 0.0]: result=tensor([0.9425]), y=[1]\nx=[1.0, 1.0]: result=tensor([0.0434]), y=[0]\n```\n\nIt's also quite nice to see that (within the bounds of the printing precision) the loss and the final results are identical.\n\nOK, so now 
that we've got rid of the optimiser, let's do the same with the\n`nn.Linear`s. [Here's the code](https://github.com/gpjt/toy-pytorch-to-jax/blob/main/pytorch_xor_no_nn_helpers.py),\nbut let's do a quick walkthrough of the differences.\n\nInstead of creating an `XORModel`, we will just generate a list of layers:\n\n```\n<     model = XORModel()\n---\n>     layers = [\n>         generate_layer_parameters(2, 2),\n>         generate_layer_parameters(2, 1),\n>     ]\n```\n\nZeroing out the existing gradients will also need to be done on those layers:\n\n```\n<             model.zero_grad()\n---\n>             zero_grad(layers)\n```\n\n...and likewise our loss calculations and the `step` function will need to use them:\n\n```\n<             loss = calculate_loss(model, torch.tensor(x), torch.tensor(y))\n---\n>             loss = calculate_loss(layers, torch.tensor(x), torch.tensor(y))\n58c76\n<             step(model, learning_rate)\n---\n>             step(layers, learning_rate)\n```\n\nWe used a couple of new helper functions there; this one generates the\ninitial weights and biases for the layers (based on the [docs for torch.nn.Linear](https://docs.pytorch.org/docs/2.12/generated/torch.nn.Linear.html)):\n\n``` python\ndef generate_layer_parameters(d_in, d_out):\n    root_k = math.sqrt(1. 
/ d_in)\n    weights = (torch.rand(d_out, d_in) * 2 * root_k) - root_k\n    biases = (torch.rand(d_out) * 2 * root_k) - root_k\n    return {\n        \"weights\": weights.requires_grad_(),\n        \"biases\": biases.requires_grad_(),\n    }\n```\n\nNote that both of the tensors we created, the `weights` and the `biases`, need to be\nexplicitly told, using `requires_grad_`, that we're going to want PyTorch to track\ngradients on them.\n\nZeroing out the gradients is just a case of chugging through each layer, and then for each\nsetting the weights' and the biases' gradients to `None`:\n\n``` python\ndef zero_grad(layers):\n    for layer in layers:\n        for p in (layer[\"weights\"], layer[\"biases\"]):\n            p.grad = None\n```\n\nNow, to calculate the loss, we're actually not changing much. We had this:\n\n``` python\ndef calculate_loss(model, inputs, target):\n    result = model(inputs)\n    return ((result - target) ** 2).mean()\n```\n\n...and now we just change it to this:\n\n``` python\ndef calculate_loss(layers, inputs, target):\n    result = forward(layers, inputs)\n    return ((result - target) ** 2).mean()\n```\n\nThat is, we've added a new function `forward` to do a forward pass through the\ngiven layers with the given parameters. 
That looks like this:\n\n``` python\ndef forward(layers, inputs):\n    x = inputs\n    for layer in layers:\n        x = torch.sigmoid(\n            x @ layer[\"weights\"].T + layer[\"biases\"]\n        )\n    return x\n```\n\nA quick tweak to use `forward` in the printing of the results at the end:\n\n```\n<             result = model(torch.tensor(x))\n---\n>             result = forward(layers, torch.tensor(x))\n```\n\n...and we're done!\n\nLet's run it:\n\n``` bash\ngiles@perry:~/Dev/toy-pytorch-to-jax (main)$ uv run pytorch_xor_no_nn_helpers.py\nLoss at epoch 0: 0.279327\nLoss at epoch 1000: 0.254715\nLoss at epoch 2000: 0.254279\nLoss at epoch 3000: 0.253985\nLoss at epoch 4000: 0.253649\nLoss at epoch 5000: 0.251566\nLoss at epoch 6000: 0.189218\nLoss at epoch 7000: 0.030092\nLoss at epoch 8000: 0.006665\nLoss at epoch 9000: 0.003516\nTrained in 3.504s\nLoss at end: 0.003516\nx=[0.0, 0.0]: result=tensor([0.0483]), y=[0]\nx=[0.0, 1.0]: result=tensor([0.9567]), y=[1]\nx=[1.0, 0.0]: result=tensor([0.9425]), y=[1]\nx=[1.0, 1.0]: result=tensor([0.0434]), y=[0]\n```\n\nEven faster! Sounds like there aren't any nice pre-baked optimisations in that part of PyTorch, then... But again, within the bounds of our precision, those are exactly the same numbers as we got from the original PyTorch version, which is very reassuring.\n\nOK, now that we've got something that's kind of JAX-shaped, let's port it over. 
I think\nit's worth showing all of the code for that (though it's\n[here on GitHub](https://github.com/gpjt/toy-pytorch-to-jax/blob/main/pure_jax_xor_no_jit.py)\nif you want to view it there), and then I'll highlight the important diffs separately.\n\n``` python\nimport math\nimport time\n\nimport jax\nimport jax.numpy as jnp\n\njax.config.update(\"jax_platform_name\", \"cpu\")\n\ndata = [\n    ([0., 0.], [0]),\n    ([0., 1.], [1]),\n    ([1., 0.], [1]),\n    ([1., 1.], [0]),\n]\n\ndef generate_layer_parameters(key, d_in, d_out):\n    weight_key, bias_key = jax.random.split(key)\n    root_k = math.sqrt(1. / d_in)\n    weights = (jax.random.uniform(weight_key, shape=(d_out, d_in)) * 2 * root_k) - root_k\n    biases = (jax.random.uniform(bias_key, shape=(d_out,)) * 2 * root_k) - root_k\n    return {\n        \"weights\": weights,\n        \"biases\": biases,\n    }\n\ndef forward(layers, inputs):\n    x = inputs\n    for layer in layers:\n        x = jax.nn.sigmoid(\n            x @ layer[\"weights\"].T + layer[\"biases\"]\n        )\n    return x\n\ndef step(layers, grads, learning_rate):\n    layers = jax.tree.map(\n        lambda p, g: p - g * learning_rate,\n        layers,\n        grads,\n    )\n    return layers\n\ndef calculate_loss(layers, inputs, target):\n    result = forward(layers, inputs)\n    return ((result - target) ** 2).mean()\n\ndef main():\n    key = jax.random.key(42)\n\n    layer_1_key, layer_2_key = jax.random.split(key)\n    layers = [\n        generate_layer_parameters(layer_1_key, 2, 2),\n        generate_layer_parameters(layer_2_key, 2, 1),\n    ]\n\n    learning_rate = 0.1\n\n    start = time.time()\n    for epoch in range(10000):\n        losses = []\n\n        for x, y in data:\n            loss, grads = jax.value_and_grad(calculate_loss)(layers, jnp.array(x), jnp.array(y))\n            losses.append(loss.item())\n\n            layers = step(layers, grads, learning_rate)\n\n        if epoch % 1000 == 0:\n            avg_loss = 
sum(losses) / len(losses)\n            print(f\"Loss at epoch {epoch}: {avg_loss:.6f}\")\n\n    end = time.time()\n\n    print(f\"Trained in {end - start:.3f}s\")\n\n    print(f\"Loss at end: {avg_loss:.6f}\")\n\n    for x, y in data:\n        result = forward(layers, jnp.array(x))\n        print(f\"{x=}: {result=}, {y=}\")\n\nif __name__ == \"__main__\":\n    main()\n```\n\nIf you look at it side-by-side with [the previous PyTorch implementation](https://github.com/gpjt/toy-pytorch-to-jax/blob/main/pytorch_xor_no_nn_helpers.py),\nyou'll see that it's really similar! Running `diff` between them makes them look\nmore different than they are, because of the extra threading through of keys that we\nneed to do in order to satisfy JAX's strict constraints on random number handling\n(and of course there are function name changes like `torch.rand` becoming `jax.random.uniform`\nand `torch.sigmoid` becoming `jax.nn.sigmoid`). But the important changes are much smaller.\n\nFirstly, weights and biases no longer need to know that we'll want to track gradients for them, because that's all handled by the tracers that JAX wraps around them:\n\n```\n<         \"weights\": weights.requires_grad_(),\n<         \"biases\": biases.requires_grad_(),\n---\n>         \"weights\": weights,\n>         \"biases\": biases,\n```\n\nRelatedly, the `zero_grad` function that iterated over the layers and zeroed out\nthe existing gradients is completely gone. PyTorch accumulates gradients into each parameter's `.grad`\nattribute, which is why they had to be cleared before each backward pass; JAX instead computes the gradients\nafresh every time we ask for them and hands them back as new values, so there is nothing left over to zero out.\n\nThe `step` function is still there, but it's much simpler. Before we get to that, let's take a look at the way we're getting the gradients for it, in the main training loop. 
Here's the diff:\n\n```\n<             loss = calculate_loss(layers, torch.tensor(x), torch.tensor(y))\n<             loss.backward()\n---\n>             loss, grads = jax.value_and_grad(calculate_loss)(layers, jnp.array(x), jnp.array(y))\n```\n\nHopefully the change there will be nice and familiar from the start of this post:\nwe've moved from the PyTorch procedural \"do a forward pass then do the backward pass\"\nto the JAX maths-y \"work out the gradients for this function\". `value_and_grad` is\na utility function that does the same as the `grad` we encountered then, but as well as\nreturning the gradients, it also returns\nthe value of `calculate_loss` with the given parameters,\nwhich is useful for our logging.\n\nNow, remember that `layers` is a list of dictionaries, something like this:\n\n```\n[\n    {\n        'biases': Array([-0.11810607, -0.58481467], dtype=float32),\n        'weights': Array([[-0.37359995,  0.6218162 ], [-0.4298191 ,  0.15088385]], dtype=float32)\n    },\n    {\n        'biases': Array([-0.49658495], dtype=float32),\n        'weights': Array([[-0.38409787,  0.6165393 ]], dtype=float32)\n    }\n]\n```\n\nAnd also remember that `grad` -- and likewise `value_and_grad` -- has that smart trick\nof returning the gradients in the same PyTree structure as the argument that we're\ntaking the derivative with respect to. So `grads` will also be a list of dictionaries,\neach of which has `weights` and `biases`.\n\nNow, as I mentioned earlier, JAX has a useful function called `jax.tree.map`. 
Like the [Python map](https://docs.python.org/3/library/functions.html#map)\nfunction that maps a function over one or more iterables, JAX's version maps a function over one or\nmore things with the same PyTree structure.\n\nSo, because `layers` and `grads` have the same\nstructure, our `step` function can just use it to apply simple gradient descent\nlike this:\n\n``` python\ndef step(layers, grads, learning_rate):\n    layers = jax.tree.map(\n        lambda p, g: p - g * learning_rate,\n        layers,\n        grads,\n    )\n    return layers\n```\n\nVery clean :-)\n\nThat's it! A full JAX implementation of our toy example, and when we run it:\n\n``` bash\ngiles@perry:~/Dev/toy-pytorch-to-jax (main)$ uv run pure_jax_xor_no_jit.py\nAn NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.\nLoss at epoch 0: 0.267455\nLoss at epoch 1000: 0.247348\nLoss at epoch 2000: 0.061305\nLoss at epoch 3000: 0.008652\nLoss at epoch 4000: 0.004108\nLoss at epoch 5000: 0.002627\nLoss at epoch 6000: 0.001912\nLoss at epoch 7000: 0.001496\nLoss at epoch 8000: 0.001224\nLoss at epoch 9000: 0.001034\nTrained in 104.540s\nLoss at end: 0.001034\nx=[0.0, 0.0]: result=Array([0.03008602], dtype=float32), y=[0]\nx=[0.0, 1.0]: result=Array([0.97214633], dtype=float32), y=[1]\nx=[1.0, 0.0]: result=Array([0.96557194], dtype=float32), y=[1]\nx=[1.0, 1.0]: result=Array([0.02664344], dtype=float32), y=[0]\n```\n\n...it works! So, let's move on to...\n\nHang on:\n\n```\nTrained in 104.540s\n```\n\nYikes. That's almost 30 times slower than the last PyTorch version. But then -- we did all of that work to port the code over to JAX, which is great because it has a JIT, and then we didn't use the JIT. Whoops!\n\nAdding a few `@jax.jit` decorators helps. 
If we add them to the `forward`, `step` and `calculate_loss` functions, then we get [this code](https://github.com/gpjt/toy-pytorch-to-jax/blob/main/pure_jax_xor_initial_jit.py),\nwhich is faster:\n\n``` bash\ngiles@perry:~/Dev/toy-pytorch-to-jax (main)$ uv run pure_jax_xor_initial_jit.py\nAn NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.\nLoss at epoch 0: 0.267455\nLoss at epoch 1000: 0.247348\nLoss at epoch 2000: 0.061305\nLoss at epoch 3000: 0.008652\nLoss at epoch 4000: 0.004108\nLoss at epoch 5000: 0.002627\nLoss at epoch 6000: 0.001912\nLoss at epoch 7000: 0.001496\nLoss at epoch 8000: 0.001224\nLoss at epoch 9000: 0.001034\nTrained in 27.663s\nLoss at end: 0.001034\nx=[0.0, 0.0]: result=Array([0.03008603], dtype=float32), y=[0]\nx=[0.0, 1.0]: result=Array([0.97214633], dtype=float32), y=[1]\nx=[1.0, 0.0]: result=Array([0.96557194], dtype=float32), y=[1]\nx=[1.0, 1.0]: result=Array([0.02664347], dtype=float32), y=[0]\n```\n\n...but it's still almost eight times slower than the PyTorch code.\n\nHow can we make it faster? Well, perhaps we can do better if we put more of the loop into the JITted stuff. Right now, the core of our training loop looks like this:\n\n```\n        for x, y in data:\n            loss, grads = jax.value_and_grad(calculate_loss)(layers, jnp.array(x), jnp.array(y))\n            losses.append(loss.item())\n\n            layers = step(layers, grads, learning_rate)\n```\n\n`calculate_loss` and `step` are JITted. But what happens if we try to JIT a larger\nchunk of the loop? 
We can move the gradient calculation (forward and backward passes together) and the step into a single JITted function:\n\n``` python\n@jax.jit\ndef train_step(layers, inputs, targets, learning_rate):\n    loss, grads = jax.value_and_grad(calculate_loss)(layers, inputs, targets)\n    layers = step(layers, grads, learning_rate)\n    return layers, loss\n```\n\n...and then call it in the loop like this:\n\n```\n        for x, y in data:\n            layers, loss = train_step(layers, jnp.array(x), jnp.array(y), learning_rate)\n            losses.append(loss.item())\n```\n\nWith that, all of the JAX code apart from input and target wrangling is moved into a JITted function.\nWe get [this code](https://github.com/gpjt/toy-pytorch-to-jax/blob/main/pure_jax_xor_final_jit.py), and running it gives us this:\n\n``` bash\ngiles@perry:~/Dev/toy-pytorch-to-jax (main)$ uv run pure_jax_xor_final_jit.py\nAn NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.\nLoss at epoch 0: 0.267455\nLoss at epoch 1000: 0.247348\nLoss at epoch 2000: 0.061305\nLoss at epoch 3000: 0.008652\nLoss at epoch 4000: 0.004108\nLoss at epoch 5000: 0.002627\nLoss at epoch 6000: 0.001912\nLoss at epoch 7000: 0.001496\nLoss at epoch 8000: 0.001224\nLoss at epoch 9000: 0.001034\nTrained in 2.432s\nLoss at end: 0.001034\nx=[0.0, 0.0]: result=Array([0.03008603], dtype=float32), y=[0]\nx=[0.0, 1.0]: result=Array([0.97214633], dtype=float32), y=[1]\nx=[1.0, 0.0]: result=Array([0.96557194], dtype=float32), y=[1]\nx=[1.0, 1.0]: result=Array([0.02664347], dtype=float32), y=[0]\n```\n\nWoohoo! That's 2.432s versus 3.504s for the fastest PyTorch version, making it almost 45% faster :-)\n\nSo: porting to JAX alone gives us nice maths-y code, but we need to JIT it properly to get performance that matches PyTorch. 
(The fact that it's faster than PyTorch in this case is not something that I think you could rely on -- this is, after all, a toy example.)\n\nIt's also an interesting indicator that you actually need to think about\nwhat to JIT. 
My initial thought, \"just whack an `@jit` on the inner stuff\", was\nnot enough. We needed to do more than that. I've just had an interesting chat\nwith Claude Opus 4.8 about that, though, and will probably post more about it later.\nFor now, I think a useful rule of thumb is to wrap stuff in `@jit` at as high a level\nas you reasonably can, to maximise coverage.\n\nSo, this completes the happy part of this post -- I've shown what it can do, how nicely it maps to the maths, and how it's (relatively) easy to make it fast. What are the downsides?\n\n### Why JAX is doomed\n\nAnother deliberately overly-strident heading ;-)\n\nI've been programming for more than 40 years, and working professionally in the tech industry for more than 30. I'd like to feel that this makes me a better engineer than I was when I was first starting out, but I can confidently say that it has made me a much more cynical one.\n\nOver that period, I've come to categorise new APIs, languages, and tools into three approximate groups: godawful hacks, solid but not overly inspiring engineering, and things of beauty. They're loose categories, and most things fall somewhere between two of them. But I think they hold reasonably well.\n\nMy cynicism and experience tell me that:\n\n- Horrible hacks can inexplicably become popular, but normally die off when people get tired of swearing at them. (Though sometimes a large installed base means that they linger.)\n- Things of beauty get people excited, and often pull in the best engineers. But eventually, they drop by the wayside. 
Perhaps there's some hidden flaw that no-one noticed at the outset, or perhaps the mental model you need to build in order to use them effectively is too complicated for them to get to critical mass.\n- Solid, boring engineering wins in the long term.\n\nWhen we were building our programmable spreadsheet, [Resolver One](/resolver-one),\nsome of the team pointed out that a functional language -- specifically,\n[Haskell](https://www.haskell.org/) -- would be a better fit than Python. It was\na tough decision to stick with Python, and I'm still not 100% sure it was the right one.\n\nBut I do remember having sales meetings with quants at various financial firms about it,\nand in those meetings, some of the potential customers also suggested a Haskell port.\nI'm not saying that there's a perfect correlation between where we heard that,\nand the later notes in our sales status spreadsheet saying\n[\"client being acquired by a non-bankrupt competitor, all expenditure on hold\"](/2008/11/do-one-thing-and-do-it-well)\nduring the 2008 financial crisis. But I'm not *not* saying that either.\n\nIf you've read this far, you can probably tell that I see PyTorch as solid engineering, and JAX as closer to a thing of beauty.\n\nMaybe it's just the cynicism of age, but let me try to articulate the things I worry might put JAX into the \"beautiful but doomed\" side of the \"beautiful\" category.\n\nFirstly, I'm not convinced by the way that JAX, with its JIT, requires you to try to write Python as if it were a functional language. 
It's easy enough to see that this isn't functional:\n\n``` python\n@jax.jit\ndef addY(x):\n    return x + y\n```\n\n...but harder with this:\n\n``` python\ndef f(x):\n    return x + random.randint(1, 10)\n```\n\nEven worse, the way that tracing works means that you have even more constraints than \"just\" being functional would require -- remember this example from earlier?\n\n``` python\n@jax.jit\ndef f(x):\n    if x > 2:\n        return x ** 2\n    return x\n```\n\nPython is not a functional language, and deliberately so. Trying to make it one is always going\nto lead to weird bugs (for example, how the value of the global `y` on the first run\nwould be baked into that `addY` function) and hard-to-understand error messages (you\nreally need to be clued-up to work out what `Attempted boolean conversion of traced array with shape bool[]` means).\n\nThe `jax.lax` package -- for example, the `cond` function we used to\nwork around the fact that JAX could not \"see\" the Python `if` way back in this post --\nfeels like a bit of an ugly workaround. Python has its own control-flow statements, but they\ndon't work with the JIT's tracing, so we have to re-implement them as JAX functions. Hmmm.\n\nNow, I've written extensively above about how JAX's restrictions, however confusing, enable a lot of the amazing stuff that wouldn't be possible in normal PyTorch. What if there were some way to write PyTorch code and compile it directly to something that can execute on the hardware?\n\nIt turns out that as of 2023, there is: [`torch.compile`](https://docs.pytorch.org/tutorials/intermediate/torch_compile_tutorial.html).\nFrom what I understand, you're meant to be able to just attach it to your code and\nit gets JITted. But unlike JAX, you don't need to restrict the code you write. 
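\n\nHere is a minimal sketch of what that looks like. It is not code from the post's repo, and the function name is made up for illustration. The function contains the same kind of data-dependent Python `if` that trips up `jax.jit`, and it is wrapped in `torch.compile`. The `backend=\"eager\"` argument is an assumption, chosen only so that the sketch runs without a C++ toolchain. That backend runs the captured graphs with ordinary PyTorch ops rather than generating optimised kernels, so the sketch shows what gets accepted, not how fast it runs.\n\n``` python\nimport torch\n\ndef square_if_large(x):\n    # A data-dependent Python branch: under jax.jit this is the pattern\n    # that fails with a boolean-conversion error on a traced array.\n    if x.sum() > 2:\n        return x ** 2\n    return x\n\n# backend=\"eager\" captures graphs but skips kernel generation.\ncompiled = torch.compile(square_if_large, backend=\"eager\")\n\nprint(compiled(torch.tensor([3.0])))  # tensor([9.])\nprint(compiled(torch.tensor([1.0])))  # tensor([1.])\n```\n\n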
I've\nnot investigated in much depth (after all, this post is already absurdly long and\nhas taken more than a month on and off to put together), but it looks like it handles\nstuff that can't be compiled by using a concept of a \"graph break\" -- that is, it happily\nJITs what it can, then if it hits something that it can't JIT, it will cache the\n\"work so far\" as one compiled unit, run the Python code for the unJITable stuff, then\n(when it can) drop back into JIT mode.\n\nThe best of both worlds? I don't know, and would need to spend much more time investigating\nin order to find out. But I can say that for my minimal-effort\n[port of my toy XOR code](https://github.com/gpjt/toy-pytorch-to-jax/blob/main/pytorch_xor_with_compile.py),\nfollowing the structure of the JITted JAX version, it really did not help:\n\n``` bash\ngiles@perry:~/Dev/toy-pytorch-to-jax (main)$ uv run pytorch_xor_with_compile.py\nLoss at epoch 0: 0.279327\nLoss at epoch 1000: 0.254715\nLoss at epoch 2000: 0.254279\nLoss at epoch 3000: 0.253985\nLoss at epoch 4000: 0.253649\nLoss at epoch 5000: 0.251566\nLoss at epoch 6000: 0.189218\nLoss at epoch 7000: 0.030091\nLoss at epoch 8000: 0.006665\nLoss at epoch 9000: 0.003516\nTrained in 6.688s\nLoss at end: 0.003516\nx=[0.0, 0.0]: result=tensor([0.0483]), y=[0]\nx=[0.0, 1.0]: result=tensor([0.9567]), y=[1]\nx=[1.0, 0.0]: result=tensor([0.9425]), y=[1]\nx=[1.0, 1.0]: result=tensor([0.0434]), y=[0]\n```\n\nFor those who are keeping track, that's slower than the uncompiled version, which came in at about 3.5s. And the issue doesn't seem to be an up-front cost of JITting that would be paid off if we ran for more epochs -- each individual \"Loss at epoch XXX\" print comes out more slowly.\n\nAgain, for the sake of sanity I'm not going to dig into it further, especially given that\nthis is a tiny toy model and probably about as far from the target use case of\n`torch.compile` as you can get. 
But it's something\nwell worth noting for the future.\n\nStepping back: one other way of looking at this is that Python might just be the wrong language\nin which to write code that compiles to GPUs. I'm learning JAX right now so that I\ncan re-implement my existing [LLM from scratch](/llm-from-scratch) project in something\nother than PyTorch, to make sure that I really understand it. I\n[asked people on X/Twitter for votes or ideas](https://x.com/gpjt/status/1985434030880293004),\nand while JAX won, [Jeremy Howard suggested Mojo](https://x.com/jeremyphoward/status/1985784350412304390).\n\n[Mojo](https://mojolang.org/) is a Pythonic language that compiles directly to\nCPU or GPU code, so it deliberately only includes features that can be compiled that way.\nUnfortunately, it's lower-level than I really wanted for this project (and, importantly, does not\nhave built-in autograd support).\n\nBut if it did -- if, for example, there were a library like JAX for it -- perhaps it would be better than using Python as the foundation? I've looked for something like that, but to no avail: there are some work-in-progress projects, but nothing ready for use.\n\nAt the end of the day, I think further experience is essential if I'm going to come to a solid\nopinion on JAX. Experience with other tools can only get you so far, and it's easy\nto go wrong by pattern-matching what you're looking at against things that you've seen before,\nespecially when you're old and cynical.\nAll I can say at this point is that JAX is making my \"beautiful but doomed\" spidey-sense\ntingle.[8](#fn-8)\n\n### Wrapping up\n\nThe title of this post is important -- these are my impressions on first looking into JAX, not the considered thoughts of someone who's spent months or years working with it. 
I've only scratched the surface, and haven't even touched the larger JAX ecosystem, or indeed its powerful handling of memory sharding for multi-GPU or even multi-node setups (which may well be one of its biggest advantages).\n\nMy next step is going to be to implement a GPT-2-style LLM in JAX, probably using Flax and Optax as helpers, and perhaps by the time I'm done with that I'll have changed my views.\n\nBut at this point -- after working through the tutorials and porting some toy models to get at least an initial feel for it -- I've come to the conclusion that I like it. The question is, do I like it the way I liked Python when I first came to it (\"this thing is really neat and clean, even if it has flaws\"), or is it more like the way I liked Haskell (\"this is a stunning thing of beauty and is completely doomed in the real world\")?\n\nTime will tell. But in the meantime, if you've been working with JAX for some time and want to counter any of the points I made, if I've completely misunderstood anything, or if you have any corrections, then please let me know! After all, explorers in areas new to them are prone to making mistakes from time to time...\n\nThe forest of Skund was indeed enchanted, which was nothing unusual on the Disc, and was also the only forest in the whole universe to be called -- in the local language -- Your Finger You Fool, which was the literal meaning of the word Skund.\n\nThe reason for this is regrettably all too common. When the first explorers from the warm lands around the Circle Sea travelled into the chilly hinterland they filled in the blank spaces on their maps by grabbing the nearest native, pointing at some distant landmark, speaking very clearly in a loud voice, and writing down whatever the bemused man told them. Thus were immortalised in generations of atlases such geographical oddities as Just A Mountain, I Don't Know, What? and, of course, Your Finger You Fool.\n\nRainclouds clustered around the bald heights of Mt. 
Oolskunrahod ('Who is this Fool who does Not Know what a Mountain is') and the Luggage settled itself more comfortably under a dripping tree, which tried unsuccessfully to strike up a conversation.\n\n1. Specifically, prior to the introduction of `torch.compile` -- more about that later.[↩](#fnref-1)\n\n2. That's something I find myself constantly forgetting; I'll talk about \"the loss landscape\" as if it's something our training loop is exploring. And, of course, there is an overall loss landscape across all of the training data as a whole, but in any given iteration through the training loop, the loss is relative to the specific batch we're looking at.[↩](#fnref-2)\n\n3. You can also pass in an `argnums` argument, zero by default, to tell it to take the derivative with respect to a different parameter or with respect to a sequence of parameter indexes. If you give a sequence, it will return a tuple of gradients. Additionally, there's a `value_and_grad` that returns a tuple of the value of `f` and the gradients, which is useful for tracking loss as you train -- we'll use that later on.[↩](#fnref-3)\n\n4. You can also make classes \"PyTree-compatible\" by providing helper functions that map to and from that representation.[↩](#fnref-4)\n\n5. A reminder if your memory of Python decorator syntax is rusty -- this:\n\n   ``` python\n   @a_decorator\n   def some_function(x):\n       ...\n   ```\n\n   ...is just syntactic sugar for this:\n\n   ``` python\n   def some_function(x):\n       ...\n\n   some_function = a_decorator(some_function)\n   ```\n\n   [↩](#fnref-5)\n\n6. It's a tad more complicated than that -- the metadata for array tracers also contains the shape. More about that later.[↩](#fnref-6)\n\n7. For the pedantic: over ten runs of each, the numbers were pretty stable.[↩](#fnref-7)\n\n8. In case you're thinking that JAX is backed by Google and guaranteed to thrive because of that, remember [Ada](https://en.wikipedia.org/wiki/Ada_(programming_language)). Backed by the US Department of Defense. 
For its time, well-designed and elegant. It's still used, but it's hardly mainstream... I remember reading about it in *Byte* magazine back in 1988 or so, and had an \"it's so beautiful\" moment then too. To be fair to me, I was 14.[↩](#fnref-8)", "url": "https://wpnews.pro/news/first-looking-into-jax", "canonical_source": "https://www.gilesthomas.com/2026/05/on-first-looking-into-jax", "published_at": "2026-05-30 18:01:44+00:00", "updated_at": "2026-05-30 18:18:22.075183+00:00", "lang": "en", "topics": ["machine-learning", "large-language-models", "neural-networks", "artificial-intelligence", "ai-research"], "entities": ["PyTorch", "Sebastian Raschka", "Andrej Karpathy", "John Keats"], "alternates": {"html": "https://wpnews.pro/news/first-looking-into-jax", "markdown": "https://wpnews.pro/news/first-looking-into-jax.md", "text": "https://wpnews.pro/news/first-looking-into-jax.txt", "jsonld": "https://wpnews.pro/news/first-looking-into-jax.jsonld"}}