Mamba Explained

Researchers Albert Gu and Tri Dao introduced Mamba, a State Space Model (SSM) that rivals Transformer performance while overcoming the quadratic bottleneck in attention mechanisms, enabling efficient processing of sequences up to one million tokens. Mamba achieves up to 5x faster inference than Transformers, outperforms Transformers of the same size and matches Transformers twice its size in language modeling, marking a significant shift in AI architecture for long-context tasks.

The State Space Model taking on Transformers

Right now, AI is eating the world. And by AI, I mean Transformers. Practically all the big breakthroughs in AI over the last few years are due to Transformers.

Mamba, however, is one of an alternative class of models called State Space Models (SSMs). Importantly, for the first time, Mamba promises similar performance (and crucially similar scaling laws) as the Transformer whilst being feasible at long sequence lengths (say 1 million tokens). To achieve this long context, the Mamba authors remove the “quadratic bottleneck” in the Attention Mechanism. Mamba also runs fast - like “up to 5x faster than Transformer fast”[1].

Gu and Dao, the Mamba authors, write:

“Mamba enjoys fast inference and linear scaling in sequence length, and its performance improves on real data up to million-length sequences. As a general sequence model backbone, Mamba achieves state-of-the-art performance across several modalities such as language, audio, and genomics. On language modelling, our Mamba-3B model outperforms Transformers of the same size and matches Transformers twice its size, both in pretraining and downstream evaluation.”

Here we’ll discuss:
- The advantages and disadvantages of Mamba (🐍) vs Transformers (🤖),
- Analogies and intuitions for thinking about Mamba, and
- What Mamba means for Interpretability, AI Safety and Applications.

Problems with Transformers - Maybe Attention Isn’t All You Need

We’re very much in the Transformer-era of history. ML used to be about detecting cats and dogs.
Now, with Transformers, we’re generating human-like poetry, coding better than the median competitive programmer, and solving the protein folding problem.

But Transformers have one core problem. In a transformer, every token can look back at every previous token when making predictions. For this lookback, we cache detailed information about each token in the so-called KV cache.

This pairwise communication means a forward pass is O(n²) time complexity in training (the dreaded quadratic bottleneck), and each new token generated autoregressively takes O(n) time. In other words, as the context size increases, the model gets slower.

To add insult to injury, storing this key-value (KV) cache requires O(n) space. Consequently, the dreaded CUDA out-of-memory (OOM) error becomes a significant threat as the memory footprint expands. If space were the only concern, we might consider adding more GPUs; however, with latency increasing quadratically, simply adding more compute might not be a viable solution.

On the margin, we can mitigate the quadratic bottleneck with techniques like Sliding Window Attention or clever CUDA optimisations like FlashAttention. But ultimately, for super long context windows (like a chatbot which remembers every conversation you’ve shared), we need a different approach.

Foundation Model Backbones

Fundamentally, all good ML architecture backbones have components for two important operations:
- Communication between tokens
- Computation within a token

In transformers, this is Attention (communication) and MLPs (computation). We improve transformers by optimising these two operations[2].

We would like to substitute the Attention component[3] with an alternative mechanism for facilitating inter-token communication. Specifically, Mamba employs a Control Theory-inspired State Space Model, or SSM, for Communication purposes while retaining Multilayer Perceptron (MLP)-style projections for Computation.
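To make the Communication/Computation split concrete, here is a minimal NumPy sketch of a single Transformer-style block. It is an illustration under simplifying assumptions (one attention head, no layer normalisation, random weights), and the names `causal_attention`, `mlp`, `block` and `params` are hypothetical rather than taken from any library. Note the (n, n) score matrix: that is where the quadratic cost comes from.

```python
import numpy as np

def causal_attention(x, w_q, w_k, w_v):
    """Communication: each token reads from itself and all earlier tokens."""
    n, d = x.shape
    q, k, v = x @ w_q, x @ w_k, x @ w_v
    scores = (q @ k.T) / np.sqrt(d)                     # shape (n, n): quadratic in n
    future = np.triu(np.ones((n, n), dtype=bool), k=1)  # True where the key comes after the query
    scores = np.where(future, -np.inf, scores)          # hide future tokens
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ v

def mlp(x, w_in, w_out):
    """Computation: the same two-layer network applied to each token separately."""
    return np.maximum(x @ w_in, 0.0) @ w_out

def block(x, params):
    """One simplified block: communicate, then compute, each with a residual connection."""
    x = x + causal_attention(x, params["w_q"], params["w_k"], params["w_v"])
    x = x + mlp(x, params["w_in"], params["w_out"])
    return x

# Hypothetical toy sizes and random weights, purely for illustration.
rng = np.random.default_rng(0)
n_tokens, d_model = 6, 8
shapes = {
    "w_q": (d_model, d_model), "w_k": (d_model, d_model), "w_v": (d_model, d_model),
    "w_in": (d_model, 4 * d_model), "w_out": (4 * d_model, d_model),
}
params = {name: 0.1 * rng.normal(size=shape) for name, shape in shapes.items()}

x = rng.normal(size=(n_tokens, d_model))
out = block(x, params)
print(out.shape)  # one output vector per token
```

Swapping `causal_attention` for a different token-mixing function while keeping the per-token computation is, loosely, the substitution that Mamba makes.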
Like a Transformer made up of stacked transformer blocks, Mamba is made up of stacked Mamba blocks, each combining the SSM with the MLP-style projections. We would like to understand and motivate the choice of the SSM for sequence transformations.

Motivating Mamba - A Throwback to Temple Run

Imagine we’re building a Temple Run agent[4]. It chooses if the runner should move left or right at any time. To successfully pick the correct direction, we need information about our surroundings. Let’s call the collection of relevant information the state. Here the state likely includes your current position and velocity, the position of the nearest obstacle, weather conditions, etc.

Claim 1: if you know the current state of the world and how the world is evolving, then you can use this to determine the direction to move.

Note that you don’t need to look at the whole screen all the time. You can figure out what will happen to most of the screen by noting that as you run, the obstacles move down the screen. You only need to look at the top of the screen to understand the new information and then simulate the rest.

This lends itself to a natural formulation. Let h be the hidden state, relevant knowledge about the world. Also let x be the input, the observation that you get each time. h’ then represents the derivative of the hidden state, i.e. how the state is evolving. We’re trying to predict y, the optimal next move (right or left).

Now, Claim 1 states that from the hidden state h, h’, and the new observation x, you can figure out y.

More concretely, the evolution of the state h can be described by a differential equation (Eq 1a):

$h'(t) = \mathbf{A}h(t) + \mathbf{B}x(t)$

Knowing h allows you to determine your next move y (Eq 1b):

$y(t) = \mathbf{C}h(t) + \mathbf{D}x(t)$

The system's evolution is determined by its current state and newly acquired observations. A small new observation is enough, as the majority of the state can be inferred by applying known state dynamics to its previous state.
That is, most of the screen isn’t new, it’s just a continuation of the previous state's natural downward trajectory. A full understanding of the state would enable optimal selection of the subsequent action, denoted as y.

You can learn a lot about the system dynamics by observing the top of the screen. For instance, increased velocity of this upper section suggests an acceleration of the rest of the screen as well, so we can infer that the game is speeding up[5]. In this way, even if we start off knowing nothing about the game and only have limited observations, it becomes possible to gain a holistic understanding of the screen dynamics fairly rapidly.

What’s the State?

Here, state refers to the variables that, when combined with the input variables, fully determine the future system behaviour. In theory, once we have the state, there’s nothing else we need to know about the past to predict the future. With this choice of state, the system is converted to a Markov Decision Process. Ideally, the state is a fairly small amount of information which captures the essential properties of the system. That is, the state is a compression of the past[6].

Discretisation - How To Deal With Living in a Quantised World

Okay, great! So, given some state and input observation, we have an autoregressive-style system to determine the next action. Amazing!

In practice though, there’s a little snag here. We’re modelling time as continuous. But in real life, we get new inputs and take new actions at discrete time steps[7].

We would like to convert this continuous-time differential equation into a discrete-time difference equation. This conversion process is known as discretisation. Discretisation is a well-studied problem in the literature. Mamba uses the Zero-Order Hold (ZOH) discretisation[8]. To give an idea of what’s happening morally, consider a naive first-order approximation[9].
From Equation 1a, we have

$h'(t) = \mathbf{A}h(t) + \mathbf{B}x(t)$

And for small ∆,

$h'(t) \approx \frac{h(t+\Delta) - h(t)}{\Delta}$

by the definition of the derivative.

We let:

$h_t = h(t)$

and

$h_{t+1} = h(t + \Delta)$

and substitute into Equation 1a giving:

$h_{t+1} - h_t \approx \Delta (\mathbf{A}h_t + \mathbf{B}x_t)$

$\Rightarrow h_{t+1} \approx (I + \Delta \mathbf{A}) h_t + (\Delta \mathbf{B}) x_t$

Hence, after renaming the coefficients and relabelling indices, we have the discrete representations:

$h_t = \mathbf{\bar{A}}h_{t-1} + \mathbf{\bar{B}}x_t$

$y_t = \mathbf{C}h_t + \mathbf{D}x_t$

where, under this first-order approximation, $\mathbf{\bar{A}} = I + \Delta \mathbf{A}$ and $\mathbf{\bar{B}} = \Delta \mathbf{B}$.

If you’ve ever looked at an RNN before[10] and this feels familiar - trust your instincts: We have some input x, which is combined with the previous hidden state by some transform to give the new hidden state. Then we use the hidden state to calculate the output at each time step.

Understanding the SSM Matrices

Now, we can interpret the A, B, C, D matrices more intuitively:
- A is the state transition matrix. It shows how you transition the current state into the next state. It asks “How should I forget the less relevant parts of the state over time?”
- B maps the new input into the state, asking “What part of my new input should I remember?”[11]
- C maps the state to the output of the SSM. It asks, “How can I use the state to make a good next prediction?”[12]
- D is how the new input passes through to the output. It’s a kind of modified skip connection that asks “How can I use the new input in my prediction?”

Additionally, ∆ has a nice interpretation - it’s the step size, or what we might call the linger time or the dwell time. For large ∆, you focus more on that token; for small ∆, you skip past the token immediately and don’t include it much in the next state.

And that’s it! That’s the SSM, our ~drop-in replacement for Attention (Communication) in the Mamba block. The Computation in the Mamba architecture comes from regular linear projections, non-linearities, and local convolutions.

Okay great, that’s the theory - but does this work?
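The discrete recurrence can be sketched in a few lines of NumPy. This is a minimal illustration that assumes the naive first-order discretisation from the derivation above (step matrices I + ∆A and ∆B) rather than the Zero-Order Hold rule that Mamba actually uses; the function names `discretise_euler` and `ssm_scan` and the toy one-dimensional system are hypothetical.

```python
import numpy as np

def discretise_euler(A, B, delta):
    """Naive first-order discretisation (an assumption for illustration, not ZOH)."""
    A_bar = np.eye(A.shape[0]) + delta * A
    B_bar = delta * B
    return A_bar, B_bar

def ssm_scan(A_bar, B_bar, C, D, xs):
    """Run h_t = A_bar h_{t-1} + B_bar x_t and y_t = C h_t + D x_t over a sequence."""
    h = np.zeros(A_bar.shape[0])   # start from an empty state
    ys = []
    for x in xs:                   # one step per token, like an RNN
        h = A_bar @ h + B_bar @ x  # fold the new input into the state
        ys.append(C @ h + D @ x)   # read the output from the state
    return np.stack(ys), h

# Toy 1-D system: h' = -h + x, with the state observed directly (C = 1, D = 0).
A = np.array([[-1.0]])
B = np.array([[1.0]])
C = np.array([[1.0]])
D = np.array([[0.0]])

A_bar, B_bar = discretise_euler(A, B, delta=0.1)
xs = np.ones((50, 1))              # constant input of 1 for 50 steps
ys, h_final = ssm_scan(A_bar, B_bar, C, D, xs)
print(ys[-1, 0])                   # approaches the continuous steady state h = 1
```

The state `h` has a fixed size no matter how long `xs` is, so every step costs the same amount of time and memory.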
Well…

Effectiveness vs Efficiency: Attention is Focus, Selectivity is Prioritisation

At WWDC ’97, Steve Jobs famously noted that “focusing is about saying no”. Focus is ruthless prioritisation. It’s common to think about Attention positively as choosing what to notice. In the Steve Jobs sense, we might instead frame Attention negatively as choosing what to discard.

There’s a classic intuition pump in Machine Learning known as the Cocktail Party Problem[13]. Imagine a party with dozens of simultaneous loud conversations:

Question: How do we recognise what one person is saying when others are talking at the same time?[14]

Answer: The brain solves this problem by focusing your “attention” on a particular stimulus and hence drowning out all other sounds as much as possible.

Transformers use Dot-Product Attention to focus on the most relevant tokens. A big reason Attention is so great is that the model has the potential to look back at everything that ever happened in its context. This is like photographic memory when done right.[15]

Transformers (🤖) are extremely effective. But they aren’t very efficient. They store everything from the past so that they can look back at tokens with theoretically perfect recall.

Traditional RNNs (🔁) are the opposite - they forget a lot, only recalling a small amount in their hidden state and discarding the rest. They are very efficient - their state is small. Yet they are less effective as discarded information cannot be recovered.

We’d like something closer to the Pareto frontier of the effectiveness/efficiency tradeoff. Something that’s more effective than traditional RNNs and more efficient than transformers.

The Mamba Architecture seems to offer a solution which pushes out the Pareto frontier of effectiveness/efficiency.

SSMs are as efficient as RNNs, but we might wonder how effective they are. After all, it seems like they would have a hard time discarding only unnecessary information and keeping everything relevant.
If each token is being processed the same way, applying the same A and B matrices as if in a factory assembly line for tokens, there is no context-dependence. We would like the forgetting and remembering matrices (A and B respectively) to vary and dynamically adapt to inputs.

The Selection Mechanism

Selectivity allows each token to be transformed into the state in a way that is unique to its own needs. Selectivity is what takes us from vanilla SSM models (applying the same A (forgetting) and B (remembering) matrices to every input) to Mamba, the Selective State Space Model.

In regular SSMs, A, B, C and D are learned matrices - that is $\mathbf{A} = \mathbf{A}_{\theta}$ etc. where θ represents the learned param