# On first looking into JAX

A developer who has spent years working extensively with PyTorch has begun exploring JAX, a competing machine learning framework, and reports finding it cleaner and more mathematically pure. The author contrasts JAX's functional, just-in-time compiled approach with PyTorch's procedural, piecewise-optimized design, arguing the difference reflects a fundamental divide between mathematical and engineering philosophies in deep learning tooling.

> Much have I travell'd in the realms of gold,
> And many goodly states and kingdoms seen;
> Round many western islands have I been
> Which bards in fealty to Apollo hold.
> Oft of one wide expanse had I been told
> That deep-brow'd Homer ruled as his demesne;
> Yet did I never breathe its pure serene
> Till I heard Chapman speak out loud and bold:
> Then felt I like some watcher of the skies
> When a new planet swims into his ken;
> Or like stout Cortez when with eagle eyes
> He star'd at the Pacific -- and all his men
> Look'd at each other with a wild surmise --
> Silent, upon a peak in Darien.
>
> -- John Keats, *On First Looking into Chapman's Homer*

I've been working with [PyTorch](https://pytorch.org/) quite a lot for the last couple of years, and feel like I've come to a reasonably solid understanding of how it all fits together. [Working through](/llm-from-scratch) [Sebastian Raschka](https://sebastianraschka.com/)'s book "[Build a Large Language Model (from Scratch)](https://www.manning.com/books/build-a-large-language-model-from-scratch)", training my own LLMs [locally](/2025/12/llm-from-scratch-28-training-a-base-model-from-scratch) and [in the cloud](/2026/04/llm-from-scratch-32j-interventions-trying-to-train-a-better-model-in-the-cloud), [rebuilding Andrej Karpathy's 2015-vintage RNNs](/2025/10/retro-language-models-rebuilding-karpathys-rnn-in-pytorch) -- over time, it all adds up.

But, of course, there are other frameworks, and one I kept hearing about was [JAX](https://docs.jax.dev/en/latest/index.html).
While it's less dominant than PyTorch, it has a reputation for a certain cleanliness, a certain purity. And having spent time over the last couple of weeks working through the tutorials, and translating small PyTorch examples into it, I've been really impressed.

In this post I want to give an overview -- to report back to beginners like me, still living in PyTorch-land, on my new discovery. Less like Herschel discovering Uranus, and more like a 16th-century European coming back after having discovered something that the people who lived there were perfectly well aware of.

What is this JAX thing, and how does it differ from PyTorch?

## Some theses, significantly overstated

I think that the main differences between PyTorch and JAX are something like this (but a little less strident):

- PyTorch is engineering; JAX is maths.
- PyTorch has historically been optimised piecewise; JAX is JITted.[^1]
- PyTorch is procedural; JAX tries to be functional.
- PyTorch is maximalist; JAX is minimalist.

Having overstated my claims, let me dig in and perhaps walk them back a bit. Once I've gone through them, I'll do a walkthrough of porting a simple PyTorch training loop to JAX, which should illustrate the points well. Finally, I'll wrap up with the counterargument. JAX is wonderful and shiny, and 30+ years of industry experience and cynicism makes me fear that it might be doomed :-(

But let's start with the positive! (Happy face on.)

## 1. Maths versus engineering

A simple example that nicely contrasts the different philosophies of the two frameworks is what the core of a training loop looks like. Here's how you might write one in PyTorch:

```python
optimizer.zero_grad()
result = model(inputs)
loss = loss_function(result, targets)
loss.backward()
optimizer.step()
```

This is kind of mechanistic. You're telling the computer what to do, step by step:

- Zero out the gradients that you currently have attached to the parameters.
- Do a forward pass to get the model's outputs.
- Work out the loss based on those outputs.
- Do the backward pass.
- Update the parameters based on the gradients that the backward pass attached to them.

Now let's look at a parallel JAX implementation:

```python
def calculate_loss(parameters, inputs, targets):
    result = forward(parameters, inputs)
    return loss_function(result, targets)

...

def train():
    ...
    grads = jax.grad(calculate_loss)(layers, inputs, targets)
    layers = step(layers, grads, learning_rate)
```

It's clearly very different. No explicit backward pass, no gradient-zeroing, and the forward pass and loss calculation are baked into a separate function. But why is it shaped that way?

Let's think about what we're actually doing in our training loop. The gradients are the partial derivative of the loss function $L$ with respect to the weights $W$:

$$\frac{\partial L}{\partial W}$$

Now, I'm being a bit sloppy with that notation, because $L$ is a function, and -- in the mathematical formulation -- it takes the weights as a parameter. So it would be better written like this:

$$\frac{\partial L(W)}{\partial W}$$

But that's still not quite right. In a real training loop, we're doing this in the context of a particular input batch, $X$, and its associated targets, $Y$.[^2] We might write that mathematically as this:

$$\frac{\partial L(W : X, Y)}{\partial W}$$

...where you can read the colon as "given".

Now let's look again at the JAX code to work out the gradients:

```python
grads = jax.grad(calculate_loss)(layers, inputs, targets)
```

That's an almost-perfect mirror of the maths! The `jax.grad` function takes a function `f`, and returns another function, `g`, which takes the same arguments. When you call `g`, instead of returning the result of `f`, it will return the derivative of `f` with respect to its first argument, given the values of the others.[^3]

How is it doing that magic? Let's look at a simple concrete example:

```python
def f(x, y):
    print(f"In the function {x=}, {y=}")
    return x + y
```

If you do the initial call to `grad`:

```python
g = jax.grad(f)
```

...then it just wraps `f` in a helper function. It's when you call `g` that the magic happens.
```python
g(2.0, 1.0)
```

...will print out this:

```
In the function x=GradTracer(primal=2.0, typeof(tangent)=f32[]), y=1.0
```

The first parameter -- the one with respect to which we're asking for the derivative -- is replaced by a `GradTracer` object. Because it's wrapping a float, it can be used like one, so the function executes as expected. But it also keeps track of what happens to this variable as the code executes, and essentially builds up what in PyTorch would be represented by the computation graph.

So: in PyTorch, the variables that you pass in to a function and need gradients for have to be special PyTorch objects that can keep a reference to those gradients (hence the `requires_grad` parameter that pops up frequently in PyTorch code). In JAX, it's all handled by variables being automatically wrapped in these special tracers. Once it has the results of the function as a whole, including the chain of operations that was traced, it can automatically do a backward pass, and we're done. That's really nifty!

Now, the example above was a toy one, differentiating with respect to a single scalar. In a real training loop, you're differentiating with respect to a set of weights, and those will be something more complex. But `grad` handles that gracefully. Let's see what happens if we pass in an array as the first parameter:

```python
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x, y):
...     print(f"In the function {x=}, {y=}")
...     return (x + y).sum()
...
>>> g = jax.grad(f)
>>> g(jnp.array([1., 2., 3.]), jnp.array([4., 5., 6.]))
In the function x=GradTracer(primal=[1. 2. 3.], typeof(tangent)=f32[3]), y=Array([4., 5., 6.], dtype=float32)
Array([1., 1., 1.], dtype=float32)
```

So, we've got partial derivatives with respect to the elements of the array that was the first parameter -- just what we'd need for a single-layer neural network without bias. But what about something more complicated?
For something like, say, an LLM, we have quite a lot of structure to our weights: our input embeddings, output head, all of the layers with their attention and feed-forward weights, and so on. `grad` handles that by understanding basic Python structures -- things that can be mapped to what JAX calls PyTrees. PyTrees are nested tree structures of dictionaries, lists, tuples and so on, where the leaves are numbers or JAX arrays.[^4] If you ask for gradients with respect to a variable that can be represented by a PyTree, you get them back in a form that mirrors that PyTree:

```python
>>> def f(x, y):
...     print(f"In the function {x=}, {y=}")
...     return (x["a"]["b"] + y).sum()
...
>>> g = jax.grad(f)
>>> g({"a": {"b": jnp.array([1., 2., 3.])}}, jnp.array([4., 5., 6.]))
In the function x={'a': {'b': GradTracer(primal=[1. 2. 3.], typeof(tangent)=f32[3])}}, y=Array([4., 5., 6.], dtype=float32)
{'a': {'b': Array([1., 1., 1.], dtype=float32)}}
```

If you combine that with JAX's tree-aware `map` function, you can combine those gradients with the original parameters to update them as you train. I'll show you how that works later on, when we go through an example of porting some PyTorch code to JAX.

So, all of that cool stuff was made possible by the tracer objects, which are passed in instead of the real parameters, and keep track of the computation graph (just like the graph that PyTorch attaches directly to the variables). But tracers are more generally useful than that; they really come into their own with the next JAX difference: the JIT.

## 2. JIT vs piecewise optimisation

Imagine that you've built some kind of nifty model in PyTorch. As part of it, you do a calculation something like this (where the $q_i$ are a query's token embeddings and the $d_j$ are a document's):

$$\text{MaxSim}(Q, D) = \sum_{i} \max_{j} \, q_i \cdot d_j$$

You decide that this is generally useful, so you [code it up as a CUDA kernel](https://huggingface.co/kernels/erikkaum/maxsim) and make it available to the community, like Erik Kaunismäki has with his "MaxSim" kernel. Maybe later on, it will get added to the PyTorch library as a standard component.
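To make that calculation concrete: MaxSim, as used in ColBERT-style late-interaction retrieval, scores a query against a document by taking, for each query token embedding, its highest dot product with any document token embedding, and then summing those maxima. Here's a minimal sketch of the same maths as ordinary `jax.numpy` operations. It is not Kaunismäki's kernel, and the function name `maxsim` and the toy embeddings are made up for illustration:

```python
import jax.numpy as jnp


def maxsim(query, document):
    """Hypothetical MaxSim score built from standard array operations.

    query:    (n_query_tokens, dim) array of query token embeddings
    document: (n_doc_tokens, dim) array of document token embeddings
    """
    # Dot product of every query token with every document token:
    # shape (n_query_tokens, n_doc_tokens)
    similarities = query @ document.T
    # Keep each query token's best-matching document token, then add them up
    return similarities.max(axis=1).sum()


query = jnp.array([[1.0, 0.0], [0.0, 1.0]])
document = jnp.array([[1.0, 0.0], [0.5, 0.5]])
print(maxsim(query, document))  # 1.5 (best matches 1.0 and 0.5)
```

Nothing here is exotic: it's a matrix multiply, a max and a sum, which is exactly why someone else might end up needing the same calculation for a completely different reason.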
There are a lot of optimisations like that built into PyTorch; people found that there were higher-level abstractions on top of basic tensor operations that were generally useful, so they coded up lower-level optimised versions. For example, in the LLM I've been working with, there is [an implementation of LayerNorm](/2025/07/llm-from-scratch-16-layer-normalisation). But PyTorch has [its own one built in](https://docs.pytorch.org/docs/2.12/generated/torch.nn.LayerNorm.html). And there's [a CUDA implementation](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/cuda/layer_norm_kernel.cu) that it will use automatically if it has the appropriate hardware available.

There is a problem, though. Imagine that someone else is working on a different kind of model in the future. And for reasons completely unrelated to the MaxSim calculations that Kaunismäki nicely optimised, they happen to need to do the same calculations. Now, there are two things that can happen from there:

- They don't know that the MaxSim kernel exists, so their code remains unoptimised.
- They do know that it exists, so they repurpose it for whatever their use case is.

The first is not ideal; but the second isn't great either, if what they're using it for is not a MaxSim operation in reality, just something that happens to look the same mathematically.

In the general case: all optimisations that get into PyTorch have to be carefully named so that they reflect the exact level of abstraction that they're targeting. And when people are writing PyTorch models, they need to actually know which optimised abstractions are available, and where to apply them.

Now let's look at JAX. It has an innocuous-looking decorator, `jit`, and you can use it by adding a single line before your function:

```python
@jax.jit
def selu(x, alpha=1.67, lambda_=1.05):
    return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)
```

Behind that single line is a huge amount of useful infrastructure.
Just like `grad`, it's a function that takes one function and returns another, without necessarily running the underlying code.[^5] But when you call the wrapped function for the first time, some impressive stuff happens:

```python
selu(1.234)
```

This will essentially execute the `selu` code twice.

The first time through, it will create another of those tracer objects; this time, though, it won't wrap the number 1.234 -- it will only record that it stands in for a float. It will call the Python code with that tracer, and all of the operations in the function will be run, but the result that comes out at the end will essentially just be a representation of what calculations were done in an abstract sense -- like the computation graph that was used for working out gradients, but without specific numbers in it.

JAX has a nice way to display these representations as what it calls JAXPRs, and the JAXPR for that function's representation when called with a float parameter will look something like this:

```js
{ lambda ; a:f32[]. let
    b:bool[] = gt a 0.0:f32[]
    c:f32[] = exp a
    d:f32[] = mul 1.67:f32[] c
    e:f32[] = sub d 1.67:f32[]
    f:f32[] = jit[
      name=_where
      jaxpr={ lambda ; b:bool[] a:f32[] e:f32[]. let
          f:f32[] = select_n b e a
        in (f,) }
    ] b a e
    g:f32[] = mul 1.05:f32[] f
  in (g,) }
```

That JAXPR can be compiled into the appropriate code for the platform where you're running it (x86 machine code, compiled CUDA, or the equivalent for AMD GPUs or Google's Tensor Processing Units, TPUs), and the compiled version will be cached. The key for the cache will be meta-information about the parameter -- in this case, something like "a 32-bit floating-point scalar".

Next, the compiled code -- not the original Python -- is run with the actual value of the parameter, the 1.234 that we provided.

Now, of course, the advantage of doing this is that when you call it with a different floating-point number -- say, 5.678 -- then you don't need to do the compilation again. You can just rely on the cached version.
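One way to watch that caching happen is to give the function a Python side effect. Side effects like this only run while JAX is tracing, not when the cached compiled code is reused, so they make a handy (if hacky) probe. A minimal sketch, reusing `selu` from above with a hypothetical `trace_log` list added purely for illustration:

```python
import jax
import jax.numpy as jnp

# Hypothetical log, added only for illustration. Appending to it is a Python
# side effect, so it happens while JAX traces the function, not when the
# cached compiled code runs.
trace_log = []


@jax.jit
def selu(x, alpha=1.67, lambda_=1.05):
    trace_log.append(x.shape)
    return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)


selu(1.234)                   # traced and compiled for a float32 scalar
selu(5.678)                   # same metadata, so the cached version is reused
selu(jnp.array([1.0, -2.0]))  # new shape, so traced and compiled again

print(trace_log)  # [(), (2,)]
```

The list only grows when JAX has to trace: once for the scalar signature, and once more when the argument's shape changes. In real code you'd avoid side effects inside jitted functions; this is only a probe.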
And the fact that the compiled code is cached based on the metadata means that if you call `selu` with a vector, then it will compile a new version for that, and likewise for a matrix version.[^6]

This is all really nifty, and you can see how it would help right away. But for me, at least, an excellent extra benefit is how it can save people like Erik Kaunismäki the bother of writing custom kernels.

The compilation that happens, taking the representation that it got from the tracing process and turning it into backend code, goes through an optimising compiler, [XLA](https://openxla.org/). And that compiler can recognise "standard" operations and combine them together. This won't be at the level of "standard operations" like MaxSim, of course -- more, "this looks like a convolution, let's use the standard kernel". But it does mean that instead of someone having to take code written in Python and hand-port it over to CUDA to get a GPU speedup, the same expertise can be put into improving the optimisation part of XLA to get a speedup for all code.

That's pretty amazing. However...

## 3. Procedural vs functional code

If you want something like the JIT to work properly, you need to limit the kind of code that it works with. In particular, it needs to be functional. A function must always return the same value when given the same inputs -- so this is fine:

```python
@jax.jit
def add(x, y):
    return x + y

print(add(1, 2))
print(add(1, 3))
```

...but this will cause problems:

```python
@jax.jit
def addY(x):
    return x + y

y = 2
print(addY(1))
y = 3
print(addY(1))
```

...because `y` could be changed. Specifically, because the global `y` had the value 2 during the initial traced run of the function, that value will essentially get hard-coded into the cached JITted version, so both prints in the second example will output 3.

Something slightly surprising comes out of this -- something that makes JAX code look very different to PyTorch. How we handle randomness needs to completely change.
Consider this code:

```python
import random

def f(x):
    return x + random.randint(1, 10)

random.seed(42)
print(f(1))
print(f(1))
```

As a whole, it's deterministic. But it breaks the functional requirement that the function can only depend on its inputs. Both calls to `f` take the same input, but they return different results. Even worse, if we were to do something that consumed randomness between those two calls to `f`, for example:

```python
print(f(1))
random.randint(1, 10)
print(f(1))
```

...then the second call to `f` would return something different from what it returned in the previous version. The state of the random number generator is global state kept outside the function, just like `y` in the `addY` example above.

A naive solution to this might be to make the state of the RNG explicit as a variable -- you can imagine a library that worked something like this:

```python
import updated_random  # an imaginary library

def f(x, random_state):
    return x + random_state.randint(1, 10)

random_state = updated_random.new_state(42)
print(f(1, random_state))
print(f(1, random_state))
```

That looks more functional, but when you think about it, we haven't actually fixed the problem. We're passing in the same `random_state` variable in both cases, along with the same number, but we're getting different results. It's not global, but it's still mutable behind the scenes.

What you'd actually need to do to make it purely functional would be something like this:

```python
import updated_random  # an imaginary library

def f(x, random_state):
    # Returns a new state instead of mutating the one passed in
    new_state, randint = updated_random.randint(random_state, 1, 10)
    return new_state, x + randint

initial_random_state = updated_random.new_state(42)

first_call_random_state, result = f(1, initial_random_state)
print(result)

second_call_random_state, result = f(1, first_call_random_state)
print(result)
```

The `updated_random.randint` function generates a new random integer and returns both that and the new state of the RNG; we then pass that state back along with our result. We've made the random state variables immutable, and so it's functional. But the API is getting pretty ugly pretty quickly.
So JAX does something that is equivalent, but a bit cleaner. There's a concept of a *key*, which needs to be passed into any function that consumes randomness:

```python
key = jax.random.key(42)
```

That's kind of like the `random_state` that we have in the first version of the code above. But it's immutable; when you use it, like this:

```python
jax.random.randint(key, (), 1, 11)
```

...it will not be changed, so no matter how many times you call that function with the same key, it will return the same value.

Note that `jax.random.randint` takes an inclusive lower bound and an exclusive upper bound, like Python's `range`, but unlike the stdlib's `random.randint`. It also needs to know the shape of the result -- `()` for a scalar, `(1, 2)` for a 1x2 array, and so on.

If you want it to "move on" to a new state, you use the `split` function, which takes an existing key and returns two or more new ones. So you can do something like this:

```python
import jax.random

def f(x, key):
    return x + jax.random.randint(key, (), 1, 11)

initial_key = jax.random.key(42)

first_call_key, new_key = jax.random.split(initial_key)
print(f(1, first_call_key))

second_call_key, new_new_key = jax.random.split(new_key)
print(f(1, second_call_key))
```

Now, that `new_key` and `new_new_key` stuff is a bit ugly, but while it's not OK to mutate the contents of variables in functional code, it's absolutely fine to assign a new value to an existing one, so what I've found myself doing is writing stuff like this:

```python
import jax.random

def f(x, key):
    return x + jax.random.randint(key, (), 1, 11)

key = jax.random.key(42)

first_call_key, key = jax.random.split(key)
print(f(1, first_call_key))

second_call_key, key = jax.random.split(key)
print(f(1, second_call_key))
```

However, there are more powerful ways to use `split`; I'm not confident enough at using it yet to go into that, though, so I'll hold back for now. I suspect (assuming I keep using JAX) I'll be posting about them in the future.
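The `split` calls above always produce two keys, but `jax.random.split` also accepts the number of keys to return, which is handy when you need several at once. A minimal sketch, reusing `f` from above; the choice of four keys and the `key, *subkeys` unpacking pattern are just for illustration:

```python
import jax


def f(x, key):
    return x + jax.random.randint(key, (), 1, 11)


key = jax.random.key(42)

# Ask split for four keys: keep the first as the key to carry forward,
# and use the other three for three independent calls to f
key, *subkeys = jax.random.split(key, 4)
results = [int(f(1, subkey)) for subkey in subkeys]
print(results)

# Keys are immutable, so reusing a subkey reproduces its result exactly
print(int(f(1, subkeys[0])) == results[0])  # True
```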
OK: so the JIT means that we have to write functional code, which makes things a bit fiddly -- no more global state. And that has a surprisingly big knock-on effect with randomness.

But there's another thing that comes out of the JIT and the way it does tracing. It's not a functional thing (though some of the docs seem to almost be treating it that way), but it is caused by the same kind of constraints. It's not part of my four theses above, but I think it's important enough to call out in its own subsection.

## 3.5. Control flow and values

Imagine this function:

```python
@jax.jit
def f(x):
    if x > 2:
        return x ** 2
    return x

print(f(10.0))
```

It's purely functional, so no problem there. But let's think about what the JIT is trying to do. It wants to convert the function into a simple sequence of operations, so it will create a tracer for a floating-point scalar, then call `f` with it.

When it hits that `if` statement, there will be a problem. The tracer is meant to represent any arbitrary float, so should it take the `if` branch or not? There's no good answer. It doesn't know which branch to follow -- whether the sequence should be "square it and return the result" or just "return it directly" -- and will fail with a somewhat obscure error message:

```
jax.errors.TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
```

So this gives a hard constraint on functions that you want to JIT: by default, they can't base control flow on the values you pass in.

There is a workaround -- but it comes with tradeoffs. Let's take a slightly sideways route to explain it.

Firstly, although you cannot do control flow based on the value of a parameter -- which the tracer doesn't know -- you can base it on other information that actually is stored in the tracer. Let's say that we called `f` like this:

```python
f(jax.numpy.array([[1., 2.], [3., 4.]]))
```

The tracer that would be passed in when trying to trace the function would be something representing a 2x2 array.
The shape of the parameter is part of the tracer, even though the values aren't. So you could do something like this:

```python
@jax.jit
def f(x):
    if len(x.shape) > 1:
        return x ** 2
    return x
```

...and it would work.

It's worth thinking explicitly about why this is. When you call a JITted function, it will create a tracer that contains information about the type of thing you passed in as a parameter -- scalar versus array, and if it's an array, the array's shape. It then runs the function with the tracer, gets the sequence of operations, compiles them and then stores the result in a cache keyed on the metadata -- type and, if appropriate, shape -- that it used to create the tracer. So when we call that function with a 2x2 array, we get a 2x2 array version; if we call it later with a one-dimensional array of length 2, we'll get a new version for that.

One workaround for basing control flow on values is essentially to tell the `jit` function that it should treat the values of a particular parameter as being like the metadata used for this cache keying: it should compile a new version for each value it sees, rather than just using the metadata. It takes a parameter `static_argnums` (and a matching `static_argnames`), which tells it which parameters to do that with. So, this will work:

```python
from functools import partial

@partial(jax.jit, static_argnums=(0,))
def f(x):
    if x > 2:
        return x ** 2
    return x

print(f(10.0))
```

Remember that the thing after the `@` for a decorator needs to be something that takes a function and returns a function, so we have to use `partial` to "inject" the extra argument.

However, the downside is pretty clear: every time we call `f` with a new value, it's going to have to JIT a new version of the function and cache it -- that's going to be slow and take up memory.

So, as an alternative, we can use the [`jax.lax` package](https://docs.jax.dev/en/latest/jax.lax.html#control-flow-operators).
This provides more functional-looking alternatives for control flow, which are compatible with the way the JIT works. For example, there's a `cond` function, which we can use to replace `if`s:

```python
@jax.jit
def f(x):
    return jax.lax.cond(x > 2, lambda: x ** 2, lambda: x)

print(f(10.0))
```

That feels a little bit like a workaround, but it does solve the problem. How? Well, it's worth checking the JAXPR for it:

```js
>>> jax.make_jaxpr(f)(10.0)
{ lambda ; a:f32[]. let
    b:f32[] = jit[
      name=f
      jaxpr={ lambda ; a:f32[]. let
          c:bool[] = gt a 2.0:f32[]
          d:i32[] = convert_element_type[new_dtype=int32 weak_type=False] c
          b:f32[] = cond[
            branches=(
              { lambda ; e:f32[]. let  in (e,) }
              { lambda ; f:f32[]. let g:f32[] = integer_pow[y=2] f in (g,) }
            )
          ] d a
        in (b,) }
    ] a
  in (b,) }
```

What's happened here, I think, is that the JIT has recognised the call to `jax.lax.cond` as being a primitive function in its intermediate language, so has just kept it in there. It couldn't do that with the `if` because when it was tracing, all JAX itself saw was what was happening to the tracer -- there was a boolean comparison, and then the stuff in the chosen branch happened. The fact that there was an `if` there happened in Python itself, outside JAX, so it was "invisible" to the trace.

That feels a little inelegant to me right now, and I'll come back to it later. Let's move on to the final difference between the two libraries that I want to cover: JAX's relative minimalism compared to PyTorch's more maximalist approach.

## 4. Minimalism versus maximalism

I think the smaller size of JAX -- at least in terms of its API, if not in terms of the JIT and XLA magic under the hood -- compared to the sprawl of PyTorch is not entirely unrelated to the JIT being at its core. PyTorch, after some initial design, has almost been forced to grow organically; JAX feels more carefully designed, so it doesn't have the same need to grow (though of course it can).

The reason for PyTorch's growth is, at least in part, that it needs to absorb optimisations.
If something is slow, someone needs to write a CUDA kernel for it. If there's a CUDA kernel, it needs an API. And if it is generally useful, that API becomes part of PyTorch. Multi-head attention? [There's a class for that](https://docs.pytorch.org/docs/2.12/generated/torch.nn.MultiheadAttention.html). SELU? [Yup](https://docs.pytorch.org/docs/2.12/generated/torch.nn.SELU.html). Very specific softmax approximations based on a paper published in 2016? [PyTorch has you covered](https://docs.pytorch.org/docs/2.12/generated/torch.nn.AdaptiveLogSoftmaxWithLoss.html).

By contrast, JAX doesn't even have linear layers or optimisers in the framework itself; if you want to use them, you can write them yourself (contraindicated), or you can use [libraries built on top of JAX](https://docs.jax.dev/en/latest/#ecosystem), like [Flax](https://flax.readthedocs.io/en/stable/) for common neural network components and [Optax](https://optax.readthedocs.io/en/latest/) for optimisers.

This feels like a nice division of responsibilities, and it also seems like something that would have been very hard without the JIT. So while the JAX core may well grow in the future, the design it has now puts it in a good position to grow in a more planned, well-designed manner -- rather than having to grow to absorb more and more abstractions just to keep it fast. Those abstractions can more easily sit in libraries written on top of JAX.

## Porting a toy PyTorch model to JAX

That's the 10,000-foot overview: four (or maybe four and a half) main differences between PyTorch and JAX. It's more maths-y, JITted, functional and minimalist. What does that actually mean when you get down to coding with it? Let's get into the weeds with an example.

Let's use a really simple one: training a neural network with two inputs and one hidden layer to calculate the XOR function. The code is [in this GitHub repo](https://github.com/gpjt/toy-pytorch-to-jax), but I'll put the relevant bits here in this post.
Firstly, an idiomatic PyTorch implementation:

```python
import time

import torch

data = [
    ([0., 0.], [0]),
    ([0., 1.], [1]),
    ([1., 0.], [1]),
    ([1., 1.], [0]),
]


class XORModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = torch.nn.Linear(2, 2, bias=True)
        self.layer1_activation = torch.nn.Sigmoid()
        self.layer2 = torch.nn.Linear(2, 1, bias=True)
        self.layer2_activation = torch.nn.Sigmoid()

    def forward(self, x):
        hidden = self.layer1_activation(self.layer1(x))
        output = self.layer2_activation(self.layer2(hidden))
        return output


def calculate_loss(model, inputs, target):
    result = model(inputs)
    return ((result - target) ** 2).mean()


def main():
    torch.manual_seed(42)
    model = XORModel()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    start = time.time()
    for epoch in range(10000):
        losses = []
        for x, y in data:
            optimizer.zero_grad()
            loss = calculate_loss(model, torch.tensor(x), torch.tensor(y))
            loss.backward()
            losses.append(loss.item())
            optimizer.step()
        if epoch % 1000 == 0:
            avg_loss = sum(losses) / len(losses)
            print(f"Loss at epoch {epoch}: {avg_loss:.6f}")
    end = time.time()
    print(f"Trained in {end - start:.3f}s")
    print(f"Loss at end: {avg_loss:.6f}")

    model.eval()
    with torch.no_grad():
        for x, y in data:
            result = model(torch.tensor(x))
            print(f"{x=}: {result=}, {y=}")


if __name__ == "__main__":
    main()
```

If we run that, it trains a solid-looking model in about four seconds on my machine:

```bash
giles@perry:~/Dev/toy-pytorch-to-jax (main)$ uv run pytorch_xor.py
Loss at epoch 0: 0.279327
Loss at epoch 1000: 0.254715
Loss at epoch 2000: 0.254279
Loss at epoch 3000: 0.253985
Loss at epoch 4000: 0.253649
Loss at epoch 5000: 0.251566
Loss at epoch 6000: 0.189219
Loss at epoch 7000: 0.030093
Loss at epoch 8000: 0.006666
Loss at epoch 9000: 0.003516
Trained in 4.154s
Loss at end: 0.003516
x=[0.0, 0.0]: result=tensor([0.0483]), y=[0]
x=[0.0, 1.0]: result=tensor([0.9567]), y=[1]
x=[1.0, 0.0]: result=tensor([0.9425]), y=[1]
x=[1.0, 1.0]: result=tensor([0.0434]), y=[0]
```

Now, if we're porting to JAX we need to do something about the fact that JAX doesn't have optimisers
and the neural network stuff built in. If this were a real codebase, we'd almost certainly do that by using the libraries built on top of JAX, like Flax and Optax. But for this toy example, I think it's more illustrative to strip down the PyTorch version so that it uses fewer parts of the API -- essentially so that it only uses the stuff that JAX has -- and then to port the result.

The optimiser first. [The code is here](https://github.com/gpjt/toy-pytorch-to-jax/blob/main/pytorch_xor_no_optimizer.py), but the diffs are pretty simple. Instead of creating an optimiser, we just specify our learning rate:

```
< optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
---
> learning_rate = 0.1
```

Instead of zeroing out the gradients using the optimiser, we can just ask the model to do it:

```
< optimizer.zero_grad()
---
> model.zero_grad()
```

And instead of stepping the optimiser, we call a new `step` function, passing in the model and the learning rate:

```
< optimizer.step()
---
> step(model, learning_rate)
```

The `step` function is simple enough. We just switch into `no_grad` mode so that PyTorch doesn't try to track the computation graph (working out gradients for applying gradients and triggering some kind of crazy gradient-ception), then we iterate over the model's parameters and follow the normal SGD process, subtracting the gradients times the learning rate:

```python
def step(model, learning_rate):
    # Don't track the parameter updates themselves in the computation graph
    with torch.no_grad():
        for p in model.parameters():
            if p.grad is not None:
                p -= p.grad * learning_rate
```

Running that on my machine actually works out slightly faster than the original![^7]

```bash
giles@perry:~/Dev/toy-pytorch-to-jax (main)$ uv run pytorch_xor_no_optimizer.py
Loss at epoch 0: 0.279327
Loss at epoch 1000: 0.254715
Loss at epoch 2000: 0.254279
Loss at epoch 3000: 0.253985
Loss at epoch 4000: 0.253649
Loss at epoch 5000: 0.251566
Loss at epoch 6000: 0.189219
Loss at epoch 7000: 0.030091
Loss at epoch 8000: 0.006665
Loss at epoch 9000: 0.003516
Trained in 3.806s
Loss at end: 0.003516
x=[0.0, 0.0]: result=tensor([0.0483]), y=[0]
x=[0.0, 1.0]: result=tensor([0.9567]), y=[1]
x=[1.0, 0.0]: result=tensor([0.9425]), y=[1]
x=[1.0, 1.0]: result=tensor([0.0434]), y=[0]
```

It's also quite nice to see that, within the bounds of the printing precision, the final loss and results are identical, and the intermediate losses differ only in the sixth decimal place.

OK, so now that we've got rid of the optimiser, let's do the same with the `nn.Linear`s. [Here's the code](https://github.com/gpjt/toy-pytorch-to-jax/blob/main/pytorch_xor_no_nn_helpers.py), but let's do a quick walk through the differences.

Instead of creating an `XORModel`, we will just generate a list of layers:

```
< model = XORModel()
---
> layers = [
>     generate_layer_parameters(2, 2),
>     generate_layer_parameters(2, 1),
> ]
```

Zeroing out the existing gradients will also need to be done on those layers:

```
< model.zero_grad()
---
> zero_grad(layers)
```

...and likewise our loss calculations and the `step` function will need to use them:

```
< loss = calculate_loss(model, torch.tensor(x), torch.tensor(y))
---
> loss = calculate_loss(layers, torch.tensor(x), torch.tensor(y))
58c76
< step(model, learning_rate)
---
> step(layers, learning_rate)
```

We used a couple of new helper functions there. This one generates the initial weights for the layers, based on [the docs for `torch.nn.Linear`](https://docs.pytorch.org/docs/2.12/generated/torch.nn.Linear.html):

```python
def generate_layer_parameters(d_in, d_out):
    # Uniform in [-sqrt(k), sqrt(k)) with k = 1 / d_in, as nn.Linear does
    root_k = math.sqrt(1. / d_in)
    weights = torch.rand(d_out, d_in) * 2 * root_k - root_k
    biases = torch.rand(d_out) * 2 * root_k - root_k
    return {
        "weights": weights.requires_grad_(),
        "biases": biases.requires_grad_(),
    }
```

Note that each of the tensors we created, the weights and the biases, needs to be explicitly told, using `requires_grad_()`, that we're going to want PyTorch to track gradients on them.
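For comparison, JAX needs no equivalent of `requires_grad_()`: as the PyTree example earlier showed, `jax.grad` differentiates with respect to whatever structure is passed as its first argument. A minimal sketch, with a hypothetical `loss_fn` and made-up, fixed parameter values (not the random initialisation above) in the same list-of-dicts layout:

```python
import jax
import jax.numpy as jnp

# Plain JAX arrays in the same list-of-dicts layout as `layers`;
# nothing marks them as needing gradients
params = [
    {"weights": jnp.ones((2, 2)), "biases": jnp.zeros(2)},
    {"weights": jnp.ones((1, 2)), "biases": jnp.zeros(1)},
]


def loss_fn(params, inputs, target):
    # Hypothetical forward pass plus mean-squared-error loss
    x = inputs
    for layer in params:
        x = jax.nn.sigmoid(x @ layer["weights"].T + layer["biases"])
    return ((x - target) ** 2).mean()


grads = jax.grad(loss_fn)(params, jnp.array([0.0, 1.0]), jnp.array([1.0]))

# The gradients come back in the same structure as the parameters
grad_shapes = jax.tree.map(lambda g: g.shape, grads)
print(grad_shapes)
```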
Zeroing out the gradients is just a case of chugging through each layer, and then for each one setting the weights' and the biases' gradients to `None`:

```python
def zero_grad(layers):
    for layer in layers:
        for p in [layer["weights"], layer["biases"]]:
            p.grad = None
```

Now, to calculate the loss, we're actually not changing much. We had this:

```python
def calculate_loss(model, inputs, target):
    result = model(inputs)
    return ((result - target) ** 2).mean()
```

...and now we just change it to this:

```python
def calculate_loss(layers, inputs, target):
    result = forward(layers, inputs)
    return ((result - target) ** 2).mean()
```

That is, we've added a new function, `forward`, to do a forward pass through the given layers with the given parameters. That looks like this:

```python
def forward(layers, inputs):
    x = inputs
    for layer in layers:
        x = torch.sigmoid(x @ layer["weights"].T + layer["biases"])
    return x
```

A quick tweak to use `forward` in the printing of the results at the end:

```
< result = model(torch.tensor(x))
---
> result = forward(layers, torch.tensor(x))
```

...and we're done! Let's run it:

```bash
giles@perry:~/Dev/toy-pytorch-to-jax (main)$ uv run pytorch_xor_no_nn_helpers.py
Loss at epoch 0: 0.279327
Loss at epoch 1000: 0.254715
Loss at epoch 2000: 0.254279
Loss at epoch 3000: 0.253985
Loss at epoch 4000: 0.253649
Loss at epoch 5000: 0.251566
Loss at epoch 6000: 0.189218
Loss at epoch 7000: 0.030092
Loss at epoch 8000: 0.006665
Loss at epoch 9000: 0.003516
Trained in 3.504s
Loss at end: 0.003516
x=[0.0, 0.0]: result=tensor([0.0483]), y=[0]
x=[0.0, 1.0]: result=tensor([0.9567]), y=[1]
x=[1.0, 0.0]: result=tensor([0.9425]), y=[1]
x=[1.0, 1.0]: result=tensor([0.0434]), y=[0]
```

Even faster! Sounds like there aren't any nice pre-baked optimisations in that part of PyTorch, then... But again, within the bounds of our precision, those are exactly the same numbers as we got from the original PyTorch version, which is very reassuring.

OK, now that we've got something that's kind of JAX-shaped, let's port it over.
I think it's worth showing all of the code for that (though [it's here on GitHub](https://github.com/gpjt/toy-pytorch-to-jax/blob/main/pure_jax_xor_no_jit.py) if you want to view it there), and then I'll highlight the important diffs separately.

```python
import math
import time

import jax
import jax.numpy as jnp


jax.config.update("jax_platform_name", "cpu")


data = [
    ([0., 0.], [0]),
    ([0., 1.], [1]),
    ([1., 0.], [1]),
    ([1., 1.], [0]),
]


def generate_layer_parameters(key, d_in, d_out):
    weight_key, bias_key = jax.random.split(key)
    root_k = math.sqrt(1. / d_in)
    weights = jax.random.uniform(weight_key, shape=(d_out, d_in)) * 2 * root_k - root_k
    biases = jax.random.uniform(bias_key, shape=(d_out,)) * 2 * root_k - root_k
    return {
        "weights": weights,
        "biases": biases,
    }


def forward(layers, inputs):
    x = inputs
    for layer in layers:
        x = jax.nn.sigmoid(x @ layer["weights"].T + layer["biases"])
    return x


def step(layers, grads, learning_rate):
    layers = jax.tree.map(
        lambda p, g: p - g * learning_rate,
        layers, grads,
    )
    return layers


def calculate_loss(layers, inputs, target):
    result = forward(layers, inputs)
    return ((result - target) ** 2).mean()


def main():
    key = jax.random.key(42)
    layer_1_key, layer_2_key = jax.random.split(key)
    layers = [
        generate_layer_parameters(layer_1_key, 2, 2),
        generate_layer_parameters(layer_2_key, 2, 1),
    ]

    learning_rate = 0.1

    start = time.time()
    for epoch in range(10000):
        losses = []
        for x, y in data:
            loss, grads = jax.value_and_grad(calculate_loss)(layers, jnp.array(x), jnp.array(y))
            losses.append(loss.item())
            layers = step(layers, grads, learning_rate)

        if epoch % 1000 == 0:
            avg_loss = sum(losses) / len(losses)
            print(f"Loss at epoch {epoch}: {avg_loss:.6f}")

    end = time.time()
    print(f"Trained in {end - start:.3f}s")

    print(f"Loss at end: {avg_loss:.6f}")
    for x, y in data:
        result = forward(layers, jnp.array(x))
        print(f"{x=}: {result=}, {y=}")


if __name__ == "__main__":
    main()
```

If you look at it side-by-side with [the previous PyTorch implementation](https://github.com/gpjt/toy-pytorch-to-jax/blob/main/pytorch_xor_no_nn_helpers.py), you'll see that
it's really similar! Running `diff` between them makes them look more different than they are, because of the extra threading through of keys that we need to do in order to satisfy the strict constraints on random number handling in JAX, and of course there are function name changes like `torch.rand` becoming `jax.random.uniform` and `torch.sigmoid` becoming `jax.nn.sigmoid`. But the important changes are much smaller.

Firstly, weights and biases no longer need to know that we'll want to track gradients for them, because that's all handled by the tracers that JAX wraps around them:

```
< "weights": weights.requires_grad_(),
< "biases": biases.requires_grad_(),
---
> "weights": weights,
> "biases": biases,
```

Relatedly, the `zero_grad` function that iterated over the layers and zeroed out the existing gradients is completely gone. JAX hands the gradients back to us as fresh return values (worked out via the tracers it wraps around our parameters) rather than accumulating them on the parameters themselves, so there's nothing to zero out.

The `step` function is still there, but it's much simpler. Before we get to that, let's take a look at the way we're getting the gradients for it, in the main training loop. Here's the diff:

```
< loss = calculate_loss(layers, torch.tensor(x), torch.tensor(y))
< loss.backward()
---
> loss, grads = jax.value_and_grad(calculate_loss)(layers, jnp.array(x), jnp.array(y))
```

Hopefully the change there will be nice and familiar from the start of this post: we've moved from the PyTorch procedural "do a forward pass, then do the backward pass" to the JAX maths-y "work out the gradients for this function". `value_and_grad` is a utility function that does the same as the `grad` we encountered then, but rather than just returning the gradients, it also returns the value of `calculate_loss` with the given parameters, which is useful for our logging.
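To see that `value_and_grad` really is just `grad` plus the function's value, and that the gradients come back shaped like the parameters, here's a minimal sketch. The parameter values are hypothetical (made up only so that it's deterministic), `forward` and `calculate_loss` mirror the versions in the JAX port above, and it uses `jax.tree.structure` and `jax.tree.leaves` to compare the two trees.

```python
import jax
import jax.numpy as jnp


def forward(layers, inputs):
    # Same forward pass as the JAX port: sigmoid(x @ W.T + b) for each layer.
    x = inputs
    for layer in layers:
        x = jax.nn.sigmoid(x @ layer["weights"].T + layer["biases"])
    return x


def calculate_loss(layers, inputs, target):
    result = forward(layers, inputs)
    return ((result - target) ** 2).mean()


# Hypothetical fixed parameters: a 2 -> 2 -> 1 network as a list of dicts.
layers = [
    {"weights": jnp.array([[0.5, -0.5], [0.25, 0.75]]), "biases": jnp.array([0.0, 0.1])},
    {"weights": jnp.array([[1.0, -1.0]]), "biases": jnp.array([0.2])},
]
inputs, target = jnp.array([0., 1.]), jnp.array([1.])

# One call gives both the loss and the gradients w.r.t. the first argument.
loss, grads = jax.value_and_grad(calculate_loss)(layers, inputs, target)

# The same numbers, computed separately with plain calls.
same_loss = calculate_loss(layers, inputs, target)
same_grads = jax.grad(calculate_loss)(layers, inputs, target)

# grads mirrors layers: a list of two dicts with "weights" and "biases".
print(jax.tree.structure(grads) == jax.tree.structure(layers))
print(grads[0]["weights"].shape, grads[1]["biases"].shape)
```

Because the two trees share a structure, anything that walks one tree leaf by leaf can walk the other in lockstep.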
Now, remember that `layers` is a list of dictionaries, something like this:

```
[{'biases': Array([-0.11810607, -0.58481467], dtype=float32),
  'weights': Array([[-0.37359995,  0.6218162 ],
                    [-0.4298191 ,  0.15088385]], dtype=float32)},
 {'biases': Array([-0.49658495], dtype=float32),
  'weights': Array([[-0.38409787,  0.6165393 ]], dtype=float32)}]
```

And also remember that `grad` (and likewise `value_and_grad`) has that smart trick where it returns the gradients in the same PyTree structure as the parameter that we're taking the derivative with respect to. So `grads` will also be a list of dictionaries, each of which has `weights` and `biases`.

Now, as I mentioned earlier, JAX has a useful function called `jax.tree.map`. Like [the Python `map` function](https://docs.python.org/3/library/functions.html#map) that maps a function over one or more lists, JAX's version maps a function over one or more things with the same PyTree structure. So, because `layers` and `grads` have the same structure, our `step` function can just use it to apply simple gradient descent like this:

```python
def step(layers, grads, learning_rate):
    layers = jax.tree.map(
        lambda p, g: p - g * learning_rate,
        layers, grads,
    )
    return layers
```

Very clean :-)

That's it! A full JAX implementation of our toy example, and when we run it:

```bash
giles@perry:~/Dev/toy-pytorch-to-jax (main)$ uv run pure_jax_xor_no_jit.py
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
Loss at epoch 0: 0.267455
Loss at epoch 1000: 0.247348
Loss at epoch 2000: 0.061305
Loss at epoch 3000: 0.008652
Loss at epoch 4000: 0.004108
Loss at epoch 5000: 0.002627
Loss at epoch 6000: 0.001912
Loss at epoch 7000: 0.001496
Loss at epoch 8000: 0.001224
Loss at epoch 9000: 0.001034
Trained in 104.540s
Loss at end: 0.001034
x=[0.0, 0.0]: result=Array([0.03008602], dtype=float32), y=[0]
x=[0.0, 1.0]: result=Array([0.97214633], dtype=float32), y=[1]
x=[1.0, 0.0]: result=Array([0.96557194], dtype=float32), y=[1]
x=[1.0, 1.0]: result=Array([0.02664344], dtype=float32), y=[0]
```

...it works! So, let's move on to... Hang on:

```
Trained in 104.540s
```

Yikes! It was almost 30 times slower than the PyTorch version. But then -- we did all of that work to port the code over to JAX, which is great because it has a JIT, and then we didn't use the JIT. Whoops!

Adding a few `@jax.jit` decorators helps. If we add them to the `forward`, `step` and `calculate_loss` functions, then we get [this code](https://github.com/gpjt/toy-pytorch-to-jax/blob/main/pure_jax_xor_initial_jit.py), which is faster:

```bash
giles@perry:~/Dev/toy-pytorch-to-jax (main)$ uv run pure_jax_xor_initial_jit.py
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
Loss at epoch 0: 0.267455
Loss at epoch 1000: 0.247348
Loss at epoch 2000: 0.061305
Loss at epoch 3000: 0.008652
Loss at epoch 4000: 0.004108
Loss at epoch 5000: 0.002627
Loss at epoch 6000: 0.001912
Loss at epoch 7000: 0.001496
Loss at epoch 8000: 0.001224
Loss at epoch 9000: 0.001034
Trained in 27.663s
Loss at end: 0.001034
x=[0.0, 0.0]: result=Array([0.03008603], dtype=float32), y=[0]
x=[0.0, 1.0]: result=Array([0.97214633], dtype=float32), y=[1]
x=[1.0, 0.0]: result=Array([0.96557194], dtype=float32), y=[1]
x=[1.0, 1.0]: result=Array([0.02664347], dtype=float32), y=[0]
```

...but it's still almost eight times slower than the PyTorch code. How can we make it faster?
Well, perhaps we can do better if we put more of the loop into the JITted stuff. Right now, the core of our training loop looks like this:

```python
for x, y in data:
    loss, grads = jax.value_and_grad(calculate_loss)(layers, jnp.array(x), jnp.array(y))
    losses.append(loss.item())
    layers = step(layers, grads, learning_rate)
```

`calculate_loss` and `step` are JITted. But what happens if we try to JIT a larger chunk? We can move the loss-and-gradient calculation and the step into a JITted function of their own:

```python
@jax.jit
def train_step(layers, inputs, targets, learning_rate):
    loss, grads = jax.value_and_grad(calculate_loss)(layers, inputs, targets)
    layers = step(layers, grads, learning_rate)
    return layers, loss
```

...and then call it in the loop like this:

```python
for x, y in data:
    layers, loss = train_step(layers, jnp.array(x), jnp.array(y), learning_rate)
    losses.append(loss.item())
```

With that, all of the JAX code apart from the input and target wrangling is inside a JITted function. We get [this code](https://github.com/gpjt/toy-pytorch-to-jax/blob/main/pure_jax_xor_final_jit.py), and running it gives us this:

```bash
giles@perry:~/Dev/toy-pytorch-to-jax (main)$ uv run pure_jax_xor_final_jit.py
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
Loss at epoch 0: 0.267455
Loss at epoch 1000: 0.247348
Loss at epoch 2000: 0.061305
Loss at epoch 3000: 0.008652
Loss at epoch 4000: 0.004108
Loss at epoch 5000: 0.002627
Loss at epoch 6000: 0.001912
Loss at epoch 7000: 0.001496
Loss at epoch 8000: 0.001224
Loss at epoch 9000: 0.001034
Trained in 2.432s
Loss at end: 0.001034
x=[0.0, 0.0]: result=Array([0.03008603], dtype=float32), y=[0]
x=[0.0, 1.0]: result=Array([0.97214633], dtype=float32), y=[1]
x=[1.0, 0.0]: result=Array([0.96557194], dtype=float32), y=[1]
x=[1.0, 1.0]: result=Array([0.02664347], dtype=float32), y=[0]
```

Woohoo! Almost 45% faster than the PyTorch version :-)

So: porting to JAX alone gives us nice maths-y code, but we need to JIT it properly to get performance that matches PyTorch. The fact that it's faster than PyTorch in this case is not something that I think you could rely on -- this is, after all, a toy example.

It's also an interesting indicator that you actually need to think about what to JIT. My initial thought, "just whack an `@jit` on the inner stuff", was not enough. We needed to do more than that. I've just had an interesting chat with Claude Opus 4.8 about that, though, and will probably post more about it later. For now, I think a useful rule of thumb is to wrap stuff in `@jit` at as high a level as you reasonably can, to maximise coverage.

So, this completes the happy part of this post -- I've shown what it can do, how nicely it maps to the maths, and how it's relatively easy to make it fast. What are the downsides?

## Why JAX is doomed

Another deliberately overly-strident heading ;-)

I've been programming for more than 40 years, and working professionally in the tech industry for more than 30. I'd like to feel that this makes me a better engineer than I was when I was first starting out, but I can confidently say that it has made me a much more cynical one.
Over that period, I've come to categorise new APIs, languages, and tools into three approximate groups: godawful hacks, solid but not overly inspiring engineering, and things of beauty. They're loose categories, and most things are somewhere between one and another. But I think they hold reasonably well.

My cynicism and experience tell me that:

- Horrible hacks can inexplicably become popular, but normally die off when people get tired of swearing at them. Though sometimes a large installed base means that they linger.
- Things of beauty get people excited, and often pull in the best engineers. But eventually, they drop by the wayside. Perhaps there's some hidden flaw that no-one noticed at the outset, or perhaps the mental model you need to build in order to use them effectively is too complicated for them to get to critical mass.
- Solid, boring engineering wins in the long term.

When we were building our programmable spreadsheet, [Resolver One](/resolver-one), some of the team pointed out that a functional language -- specifically, [Haskell](https://www.haskell.org/) -- would be a better fit than Python. It was a tough decision to stick with Python, and I'm still not 100% sure it was the right one. But I do remember having sales meetings with quants at various financial firms about it, and in those meetings, some of the potential customers also suggested a Haskell port. I'm not saying that there's a perfect correlation between where we heard that, and the later notes in our sales status spreadsheet saying ["client being acquired by a non-bankrupt competitor, all expenditure on hold"](/2008/11/do-one-thing-and-do-it-well) during the 2008 financial crisis. But I'm not not saying that either.

If you've read this far, you can probably tell that I see PyTorch as solid engineering, and JAX as closer to a thing of beauty.
Maybe it's just the cynicism of age, but let me try to articulate the things I worry might put JAX into the "beautiful but doomed" side of the "beautiful" category.

Firstly, I'm not convinced by the way that JAX, with its JIT, requires you to try to write Python as if it were a functional language. It's easy enough to see that this isn't functional:

```python
@jax.jit
def addY(x):
    return x + y
```

...but harder with this:

```python
def f(x):
    return x + random.randint(1, 10)
```

Even worse, the way that tracing works means that you have even more constraints than "just" being functional would require -- remember this example from earlier?

```python
@jax.jit
def f(x):
    if x > 2:
        return x * 2
    return x
```

Python is not functional, and deliberately so. Trying to make it so is always going to lead to weird bugs (for example, how the value of the global `y` on the first run would be baked into that `addY` function) and hard-to-understand error messages (you really need to be clued-up to work out what `Attempted boolean conversion of traced array with shape bool[]` means).

The `jax.lax` package -- for example, the `cond` function we used to work around the fact that JAX could not "see" the Python `if` way back in this post -- feels like a bit of an ugly workaround. Python has control flow statements, but they don't work with the JIT's tracing, so we have to re-implement them in JAX. Hmmm.

Now, I've written extensively above about how JAX's restrictions, however confusing, enable a lot of the amazing stuff that wouldn't be possible in normal PyTorch. What if there were some way to write PyTorch code and compile it directly to something that can execute on the hardware? It turns out that as of 2023, there is: [`torch.compile`](https://docs.pytorch.org/tutorials/intermediate/torch_compile_tutorial.html). From what I understand, you're meant to be able to just attach it to your code and it gets JITted. But unlike JAX, you don't need to restrict the code you write.
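Going back to the two tracing gotchas mentioned above (a global baked in at trace time, and a Python `if` on a traced value), here's a minimal sketch that reproduces both and then applies the `jax.lax.cond` workaround. The function names are my own. I'm assuming a recent JAX release, where the `if` raises `TracerBoolConversionError`; that's a subclass of `jax.errors.ConcretizationTypeError`, so catching the latter should also cover older versions.

```python
import jax
import jax.numpy as jnp

# 1. Globals are read once, at trace time, and baked into the compiled code.
y = 1.0


@jax.jit
def add_y(x):
    return x + y


first = add_y(1.0)   # traced and compiled while y == 1.0
y = 100.0
second = add_y(1.0)  # same input shape/dtype, so the cached trace is reused


# 2. A plain Python `if` on a traced value fails under jit...
@jax.jit
def branchy(x):
    if x > 2:
        return x * 2
    return x


try:
    branchy(3.0)
    failed = False
except jax.errors.ConcretizationTypeError:
    # Recent JAX raises TracerBoolConversionError, a subclass of this.
    failed = True


# ...whereas jax.lax.cond traces both branches up front and picks one when it runs.
@jax.jit
def branchy_cond(x):
    return jax.lax.cond(x > 2, lambda v: v * 2, lambda v: v, x)


print(first, second, failed)
print(branchy_cond(3.0), branchy_cond(1.0))
```

For an elementwise choice this simple, `jnp.where(x > 2, x * 2, x)` would work just as well under `jit`. `lax.cond` is more useful when the branches are expensive, since (outside of `vmap`) only the selected branch is actually executed.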
I've not investigated `torch.compile` in much depth (after all, this post is already absurdly long and has taken more than a month, on and off, to put together), but it looks like it handles stuff that can't be compiled by using a concept of a "graph break". That is, it happily JITs what it can; then, if it hits something that it can't JIT, it caches the "work so far" as one compiled unit, runs the Python code for the unJITable stuff, and drops back into JIT mode when it can. The best of both worlds?

I don't know, and would need to spend much more time investigating in order to learn. But I can say that for [my minimal-effort port of my toy XOR code](https://github.com/gpjt/toy-pytorch-to-jax/blob/main/pytorch_xor_with_compile.py), following the structure of the JITted JAX version, it really did not help:

```bash
giles@perry:~/Dev/toy-pytorch-to-jax (main)$ uv run pytorch_xor_with_compile.py
Loss at epoch 0: 0.279327
Loss at epoch 1000: 0.254715
Loss at epoch 2000: 0.254279
Loss at epoch 3000: 0.253985
Loss at epoch 4000: 0.253649
Loss at epoch 5000: 0.251566
Loss at epoch 6000: 0.189218
Loss at epoch 7000: 0.030091
Loss at epoch 8000: 0.006665
Loss at epoch 9000: 0.003516
Trained in 6.688s
Loss at end: 0.003516
x=[0.0, 0.0]: result=tensor([0.0483]), y=[0]
x=[0.0, 1.0]: result=tensor([0.9567]), y=[1]
x=[1.0, 0.0]: result=tensor([0.9425]), y=[1]
x=[1.0, 1.0]: result=tensor([0.0434]), y=[0]
```

For those who are keeping track, that's slower than the uncompiled version, which came in at about 3.5s. And the issue doesn't seem to be an up-front cost of JITting that would be paid off if we ran for more epochs -- each individual "Loss at epoch XXX" print comes out slower.

Again, for the sake of sanity I'm not going to dig into it further, especially given that this is a tiny toy model and probably about as far from the target use case of `torch.compile` as you can get. But it's something well worth noting for the future.
Stepping back: one other way of looking at this is that Python might just be the wrong language for building code that compiles to GPUs. I'm learning JAX right now so that I can re-implement my existing [LLM from scratch](/llm-from-scratch) project in something other than PyTorch, to make sure that I really understand it. I [asked people on X/Twitter for votes or ideas](https://x.com/gpjt/status/1985434030880293004), and while JAX won, [Jeremy Howard suggested Mojo](https://x.com/jeremyphoward/status/1985784350412304390).

[Mojo](https://mojolang.org/) is a Pythonic language that compiles directly to CPU or GPU code, so it explicitly only contains features that can be ported that way. Unfortunately, it's lower-level than I really wanted for this project and, importantly, does not have built-in autograd support. But if it did -- if, for example, there was a library like JAX for it -- perhaps it would be better than using Python as the foundation? I've looked for something like that, but to no avail. There are some work-in-progress projects, but nothing ready for use.

At the end of the day, I think further experience is essential if I'm going to come to a solid opinion on JAX. Experience with other tools can only get you so far, and it's easy to fail by pattern-matching what you're looking at with things that you've seen before, especially when you're old and cynical. All I can say at this point is that JAX is making my "beautiful but doomed" spidey-sense tingle.[8]

## Wrapping up

The title of this post is important -- these are my impressions on first looking into JAX, not the considered thoughts of someone who's spent months or years working with it. I've only scratched the surface, and haven't even touched the larger JAX ecosystem, or indeed its powerful handling of memory sharding for multi-GPU or even multi-node setups (which may well be one of its biggest advantages).
My next step is going to be to implement a GPT-2-style LLM in JAX, probably using Flax and Optax as helpers, and perhaps by the time I'm done with that I'll have changed my views. But at this point -- after working through the tutorials and porting some toy models to get at least an initial feel for it -- I've come to the conclusion that I like it.

The question is, do I like it like I liked Python when I first came to it -- "this thing is really neat and clean, even if it has flaws" -- or is it more like I liked Haskell -- "this is a stunning thing of beauty and is completely doomed in the real world"? Time will tell.

But in the meantime, if you've been working with JAX for some time and want to counter any of the points I made, if I've completely misunderstood anything, or if you have any corrections, then please let me know! After all, explorers in areas new to them are prone to making mistakes from time to time...

> The forest of Skund was indeed enchanted, which was nothing unusual on the Disc, and was also the only forest in the whole universe to be called -- in the local language -- Your Finger You Fool, which was the literal meaning of the word Skund.
>
> The reason for this is regrettably all too common. When the first explorers from the warm lands around the Circle Sea travelled into the chilly hinterland they filled in the blank spaces on their maps by grabbing the nearest native, pointing at some distant landmark, speaking very clearly in a loud voice, and writing down whatever the bemused man told them. Thus were immortalised in generations of atlases such geographical oddities as Just A Mountain, I Don't Know, What? and, of course, Your Finger You Fool.
>
> Rainclouds clustered around the bald heights of Mt. Oolskunrahod ('Who is this Fool who does Not Know what a Mountain is') and the Luggage settled itself more comfortably under a dripping tree, which tried unsuccessfully to strike up a conversation.
1. Specifically, prior to the introduction of `torch.compile` -- more about that later.
2. That's something I find myself constantly forgetting; I'll talk about "the loss landscape" as if it's something our training loop is exploring. And, of course, there is an overall loss landscape across all of the training data as a whole, but in any given iteration through the training loop, the loss is relative to the specific batch we're looking at.
3. You can also pass in an `argnums` argument (zero by default) to tell it to do the derivative with respect to a different parameter, or with respect to a sequence of parameter indexes. If you give a sequence, it will return a tuple of gradients. Additionally, there's a `value_and_grad` that returns a tuple of the value of `f` and the gradients, which is useful for tracking loss as you train -- we'll use that later on.
4. You can also make classes "PyTree-compatible" by providing helper functions that map to and from that representation.
5. A reminder if your memory of Python decorator syntax is rusty -- this:

    ```python
    @a_decorator
    def some_function(x):
        ...
    ```

    ...is just syntactic sugar for this:

    ```python
    def some_function(x):
        ...

    some_function = a_decorator(some_function)
    ```

6. It's a tad more complicated than that -- the metadata for array traces also contains the shape. More about that later.
7. For the pedantic: over ten runs of each, the numbers were pretty stable.
8. In case you're thinking that JAX is backed by Google and guaranteed to thrive because of that, remember [Ada](https://en.wikipedia.org/wiki/Ada_(programming_language)). Backed by the US Department of Defense. For its time, well-designed and elegant. It's still used, but it's hardly mainstream... I remember reading about it in Byte magazine back in 1988 or so, and had an "it's so beautiful" moment then too. To be fair to me, I was 14.