# Attention Backpropagation: Step by step derivation

A blog post derives the backward pass of attention mechanisms step by step, using a concrete example to illustrate gradient computation for the $Q$, $K$, and $V$ matrices. The derivation builds on the FlashAttention and FlashAttention-2 papers, focusing on the softmax and matrix multiplication operations in the forward pass.

I recently revisited the FlashAttention [1] and FlashAttention-2 [2] papers. It is really fun to derive the backward pass of attention by hand. In this blog, I will use a concrete example to illustrate the process, and I hope it makes the derivation easy to follow.

## Forward Pass

Attention [3] involves three matrices: $Q$, $K$, and $V$. Each has shape `(batch_size, num_heads, seq_len, head_dim)`. Attention is calculated as follows:

\[ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{head\_dim}}\right) V \]

Let me use a simple example to illustrate this process. We will ignore the $batch\_size$ and $num\_heads$ dimensions in this example because the matrix multiplications act on the $seq\_len$ and $head\_dim$ dimensions. We will also ignore the scaling factor $\frac{1}{\sqrt{head\_dim}}$ for simplicity.
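As a rough companion to the formula above, here is a minimal NumPy sketch of the simplified forward pass (single head, no batch dimension, no scaling factor). The function names `softmax_rows` and `attention_forward` are made up for illustration, and subtracting the row maximum inside the softmax is a common stability trick that is my addition, not part of the derivation.

```python
import numpy as np

def softmax_rows(S):
    """Row-wise softmax. Subtracting the row max does not change the result."""
    E = np.exp(S - S.max(axis=-1, keepdims=True))
    return E / E.sum(axis=-1, keepdims=True)

def attention_forward(Q, K, V):
    """Unscaled single-head attention, matching the simplification in the text."""
    S = Q @ K.T          # scores, shape (seq_len, seq_len)
    P = softmax_rows(S)  # attention weights, each row sums to 1
    O = P @ V            # output, shape (seq_len, head_dim)
    return S, P, O

# A 3x3 example with random entries, mirroring the sizes used in this post.
rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((3, 3)) for _ in range(3))
S, P, O = attention_forward(Q, K, V)
print(P.sum(axis=-1))  # every entry should be very close to 1
```

Returning `S` and `P` alongside `O` is only for convenience here, since the backward pass reuses both.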
\[ Q = \begin{bmatrix} q_{11} & q_{12} & q_{13} \\ q_{21} & q_{22} & q_{23} \\ q_{31} & q_{32} & q_{33} \end{bmatrix} \]

\[ K = \begin{bmatrix} k_{11} & k_{12} & k_{13} \\ k_{21} & k_{22} & k_{23} \\ k_{31} & k_{32} & k_{33} \end{bmatrix} \]

\[ V = \begin{bmatrix} v_{11} & v_{12} & v_{13} \\ v_{21} & v_{22} & v_{23} \\ v_{31} & v_{32} & v_{33} \end{bmatrix} \]

So, because entry $(i, j)$ of $QK^T$ is the dot product of row $i$ of $Q$ with row $j$ of $K$,

\[ QK^T = S = \begin{bmatrix} q_{11}k_{11} + q_{12}k_{12} + q_{13}k_{13} & q_{11}k_{21} + q_{12}k_{22} + q_{13}k_{23} & q_{11}k_{31} + q_{12}k_{32} + q_{13}k_{33} \\ q_{21}k_{11} + q_{22}k_{12} + q_{23}k_{13} & q_{21}k_{21} + q_{22}k_{22} + q_{23}k_{23} & q_{21}k_{31} + q_{22}k_{32} + q_{23}k_{33} \\ q_{31}k_{11} + q_{32}k_{12} + q_{33}k_{13} & q_{31}k_{21} + q_{32}k_{22} + q_{33}k_{23} & q_{31}k_{31} + q_{32}k_{32} + q_{33}k_{33} \end{bmatrix} = \begin{bmatrix} s_{11} & s_{12} & s_{13} \\ s_{21} & s_{22} & s_{23} \\ s_{31} & s_{32} & s_{33} \end{bmatrix} \]

The softmax is applied to each row of $S$:

\[ P = \text{softmax}(S) = \begin{bmatrix} \frac{\exp(s_{11})}{\exp(s_{11}) + \exp(s_{12}) + \exp(s_{13})} & \frac{\exp(s_{12})}{\exp(s_{11}) + \exp(s_{12}) + \exp(s_{13})} & \frac{\exp(s_{13})}{\exp(s_{11}) + \exp(s_{12}) + \exp(s_{13})} \\ \frac{\exp(s_{21})}{\exp(s_{21}) + \exp(s_{22}) + \exp(s_{23})} & \frac{\exp(s_{22})}{\exp(s_{21}) + \exp(s_{22}) + \exp(s_{23})} & \frac{\exp(s_{23})}{\exp(s_{21}) + \exp(s_{22}) + \exp(s_{23})} \\ \frac{\exp(s_{31})}{\exp(s_{31}) + \exp(s_{32}) + \exp(s_{33})} & \frac{\exp(s_{32})}{\exp(s_{31}) + \exp(s_{32}) + \exp(s_{33})} & \frac{\exp(s_{33})}{\exp(s_{31}) + \exp(s_{32}) + \exp(s_{33})} \end{bmatrix} = \begin{bmatrix} p_{11} & p_{12} & p_{13} \\ p_{21} & p_{22} & p_{23} \\ p_{31} & p_{32} & p_{33} \end{bmatrix} \]

\[ O = PV = \begin{bmatrix} p_{11}v_{11} + p_{12}v_{21} + p_{13}v_{31} & p_{11}v_{12} + p_{12}v_{22} + p_{13}v_{32} & p_{11}v_{13} + p_{12}v_{23} + p_{13}v_{33} \\ p_{21}v_{11} + p_{22}v_{21} + p_{23}v_{31} & p_{21}v_{12} + p_{22}v_{22} + p_{23}v_{32} & p_{21}v_{13} + p_{22}v_{23} + p_{23}v_{33} \\ p_{31}v_{11} + p_{32}v_{21} + p_{33}v_{31} & p_{31}v_{12} + p_{32}v_{22} + p_{33}v_{32} & p_{31}v_{13} + p_{32}v_{23} + p_{33}v_{33} \end{bmatrix} = \begin{bmatrix} o_{11} & o_{12} & o_{13} \\ o_{21} & o_{22} & o_{23} \\ o_{31} & o_{32} & o_{33} \end{bmatrix} \]

$O$ is the output of the attention.

## Backward Pass

In the backward pass, the input is the partial derivative of the loss $L$ with respect to $O$:

\[ \frac{\partial L}{\partial O} = \begin{bmatrix} \frac{\partial L}{\partial o_{11}} & \frac{\partial L}{\partial o_{12}} & \frac{\partial L}{\partial o_{13}} \\ \frac{\partial L}{\partial o_{21}} & \frac{\partial L}{\partial o_{22}} & \frac{\partial L}{\partial o_{23}} \\ \frac{\partial L}{\partial o_{31}} & \frac{\partial L}{\partial o_{32}} & \frac{\partial L}{\partial o_{33}} \end{bmatrix} \]

When we use a deep learning framework such as PyTorch or JAX, this derivative is computed automatically. We will use it to compute the gradients $\frac{\partial L}{\partial Q}$, $\frac{\partial L}{\partial K}$, and $\frac{\partial L}{\partial V}$.

### Gradient of $V$ and $P$

This is the most straightforward part.
Remember that $O = PV$:

\[ O = PV = \begin{bmatrix} p_{11}v_{11} + p_{12}v_{21} + p_{13}v_{31} & p_{11}v_{12} + p_{12}v_{22} + p_{13}v_{32} & p_{11}v_{13} + p_{12}v_{23} + p_{13}v_{33} \\ p_{21}v_{11} + p_{22}v_{21} + p_{23}v_{31} & p_{21}v_{12} + p_{22}v_{22} + p_{23}v_{32} & p_{21}v_{13} + p_{22}v_{23} + p_{23}v_{33} \\ p_{31}v_{11} + p_{32}v_{21} + p_{33}v_{31} & p_{31}v_{12} + p_{32}v_{22} + p_{33}v_{32} & p_{31}v_{13} + p_{32}v_{23} + p_{33}v_{33} \end{bmatrix} = \begin{bmatrix} o_{11} & o_{12} & o_{13} \\ o_{21} & o_{22} & o_{23} \\ o_{31} & o_{32} & o_{33} \end{bmatrix} \]

Take $\frac{\partial L}{\partial v_{11}}$ as an example. $v_{11}$ appears only in the first column of $O$ (in $o_{11}$, $o_{21}$, and $o_{31}$), so by the chain rule

\[ \frac{\partial L}{\partial v_{11}} = \frac{\partial L}{\partial o_{11}}\frac{\partial o_{11}}{\partial v_{11}} + \frac{\partial L}{\partial o_{21}}\frac{\partial o_{21}}{\partial v_{11}} + \frac{\partial L}{\partial o_{31}}\frac{\partial o_{31}}{\partial v_{11}} \]

Since $o_{11} = p_{11}v_{11} + p_{12}v_{21} + p_{13}v_{31}$, and similarly for $o_{21}$ and $o_{31}$,

\[ \frac{\partial o_{11}}{\partial v_{11}} = p_{11} \]

\[ \frac{\partial o_{21}}{\partial v_{11}} = p_{21} \]

\[ \frac{\partial o_{31}}{\partial v_{11}} = p_{31} \]

So

\[ \frac{\partial L}{\partial v_{11}} = \frac{\partial L}{\partial o_{11}}p_{11} + \frac{\partial L}{\partial o_{21}}p_{21} + \frac{\partial L}{\partial o_{31}}p_{31} \]

Repeating this for every entry of $V$ gives

\[ \frac{\partial L}{\partial V} = \begin{bmatrix} \frac{\partial L}{\partial v_{11}} & \frac{\partial L}{\partial v_{12}} & \frac{\partial L}{\partial v_{13}} \\ \frac{\partial L}{\partial v_{21}} & \frac{\partial L}{\partial v_{22}} & \frac{\partial L}{\partial v_{23}} \\ \frac{\partial L}{\partial v_{31}} & \frac{\partial L}{\partial v_{32}} & \frac{\partial L}{\partial v_{33}} \end{bmatrix} = \begin{bmatrix} p_{11}\frac{\partial L}{\partial o_{11}} + p_{21}\frac{\partial L}{\partial o_{21}} + p_{31}\frac{\partial L}{\partial o_{31}} & p_{11}\frac{\partial L}{\partial o_{12}} + p_{21}\frac{\partial L}{\partial o_{22}} + p_{31}\frac{\partial L}{\partial o_{32}} & p_{11}\frac{\partial L}{\partial o_{13}} + p_{21}\frac{\partial L}{\partial o_{23}} + p_{31}\frac{\partial L}{\partial o_{33}} \\ p_{12}\frac{\partial L}{\partial o_{11}} + p_{22}\frac{\partial L}{\partial o_{21}} + p_{32}\frac{\partial L}{\partial o_{31}} & p_{12}\frac{\partial L}{\partial o_{12}} + p_{22}\frac{\partial L}{\partial o_{22}} + p_{32}\frac{\partial L}{\partial o_{32}} & p_{12}\frac{\partial L}{\partial o_{13}} + p_{22}\frac{\partial L}{\partial o_{23}} + p_{32}\frac{\partial L}{\partial o_{33}} \\ p_{13}\frac{\partial L}{\partial o_{11}} + p_{23}\frac{\partial L}{\partial o_{21}} + p_{33}\frac{\partial L}{\partial o_{31}} & p_{13}\frac{\partial L}{\partial o_{12}} + p_{23}\frac{\partial L}{\partial o_{22}} + p_{33}\frac{\partial L}{\partial o_{32}} & p_{13}\frac{\partial L}{\partial o_{13}} + p_{23}\frac{\partial L}{\partial o_{23}} + p_{33}\frac{\partial L}{\partial o_{33}} \end{bmatrix} \]

So

\[ \frac{\partial L}{\partial V} = P^T \frac{\partial L}{\partial O} \]

Similarly,

\[ \frac{\partial L}{\partial P} = \begin{bmatrix} \frac{\partial L}{\partial p_{11}} & \frac{\partial L}{\partial p_{12}} & \frac{\partial L}{\partial p_{13}} \\ \frac{\partial L}{\partial p_{21}} & \frac{\partial L}{\partial p_{22}} & \frac{\partial L}{\partial p_{23}} \\ \frac{\partial L}{\partial p_{31}} & \frac{\partial L}{\partial p_{32}} & \frac{\partial L}{\partial p_{33}} \end{bmatrix} = \frac{\partial L}{\partial O}V^T \]

### Gradient of $S$

To compute the gradients of $K$ and $Q$, we need to compute the gradient of $S$ first.
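Before moving on, the two matrix-product results from the previous section, $\frac{\partial L}{\partial V} = P^T \frac{\partial L}{\partial O}$ and $\frac{\partial L}{\partial P} = \frac{\partial L}{\partial O}V^T$, can be sanity-checked numerically. The sketch below is only an illustration under my own assumptions: the loss $L = \sum_{ij} (PV)_{ij}\, g_{ij}$ is a hypothetical choice that makes $\frac{\partial L}{\partial O}$ equal to a fixed matrix, `numerical_grad` is a made-up helper name, and $P$ is treated as a free variable rather than as the output of a softmax.

```python
import numpy as np

rng = np.random.default_rng(0)
P = rng.random((3, 3))
P /= P.sum(axis=1, keepdims=True)  # rows sum to 1, like a softmax output
V = rng.standard_normal((3, 3))
dO = rng.standard_normal((3, 3))   # stands in for dL/dO

def loss(P, V):
    # Hypothetical loss chosen so that dL/dO equals dO exactly.
    return np.sum((P @ V) * dO)

def numerical_grad(f, X, eps=1e-6):
    """Central finite differences, perturbing one entry of X at a time."""
    G = np.zeros_like(X)
    for idx in np.ndindex(*X.shape):
        X_plus, X_minus = X.copy(), X.copy()
        X_plus[idx] += eps
        X_minus[idx] -= eps
        G[idx] = (f(X_plus) - f(X_minus)) / (2 * eps)
    return G

dV = P.T @ dO  # dL/dV = P^T dL/dO
dP = dO @ V.T  # dL/dP = dL/dO V^T

dV_numeric = numerical_grad(lambda X: loss(P, X), V)
dP_numeric = numerical_grad(lambda X: loss(X, V), P)
print(np.allclose(dV, dV_numeric, atol=1e-6))
print(np.allclose(dP, dP_numeric, atol=1e-6))
```

Both comparisons should agree up to floating-point error, since this loss is linear in $V$ and in $P$ separately.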
Remember that

\[ P = \text{softmax}(S) = \begin{bmatrix} \frac{\exp(s_{11})}{\exp(s_{11}) + \exp(s_{12}) + \exp(s_{13})} & \frac{\exp(s_{12})}{\exp(s_{11}) + \exp(s_{12}) + \exp(s_{13})} & \frac{\exp(s_{13})}{\exp(s_{11}) + \exp(s_{12}) + \exp(s_{13})} \\ \frac{\exp(s_{21})}{\exp(s_{21}) + \exp(s_{22}) + \exp(s_{23})} & \frac{\exp(s_{22})}{\exp(s_{21}) + \exp(s_{22}) + \exp(s_{23})} & \frac{\exp(s_{23})}{\exp(s_{21}) + \exp(s_{22}) + \exp(s_{23})} \\ \frac{\exp(s_{31})}{\exp(s_{31}) + \exp(s_{32}) + \exp(s_{33})} & \frac{\exp(s_{32})}{\exp(s_{31}) + \exp(s_{32}) + \exp(s_{33})} & \frac{\exp(s_{33})}{\exp(s_{31}) + \exp(s_{32}) + \exp(s_{33})} \end{bmatrix} = \begin{bmatrix} p_{11} & p_{12} & p_{13} \\ p_{21} & p_{22} & p_{23} \\ p_{31} & p_{32} & p_{33} \end{bmatrix} \]

For example, $s_{11}$ appears only in the first row of $P$ (in $p_{11}$, $p_{12}$, and $p_{13}$), so

\[ \frac{\partial L}{\partial s_{11}} = \frac{\partial L}{\partial p_{11}}\frac{\partial p_{11}}{\partial s_{11}} + \frac{\partial L}{\partial p_{12}}\frac{\partial p_{12}}{\partial s_{11}} + \frac{\partial L}{\partial p_{13}}\frac{\partial p_{13}}{\partial s_{11}} \]

Since $p_{11} = \frac{\exp(s_{11})}{\exp(s_{11}) + \exp(s_{12}) + \exp(s_{13})}$, the quotient rule gives

\[ \frac{\partial p_{11}}{\partial s_{11}} = \frac{\exp(s_{11})\left(\exp(s_{11}) + \exp(s_{12}) + \exp(s_{13})\right) - \exp(s_{11})\exp(s_{11})}{\left(\exp(s_{11}) + \exp(s_{12}) + \exp(s_{13})\right)^2} = \frac{\exp(s_{11})}{\exp(s_{11}) + \exp(s_{12}) + \exp(s_{13})} - \frac{\exp(s_{11})^2}{\left(\exp(s_{11}) + \exp(s_{12}) + \exp(s_{13})\right)^2} = p_{11} - p_{11}^2 \]

Since $p_{12} = \frac{\exp(s_{12})}{\exp(s_{11}) + \exp(s_{12}) + \exp(s_{13})}$,

\[ \frac{\partial p_{12}}{\partial s_{11}} = \frac{0 \cdot \left(\exp(s_{11}) + \exp(s_{12}) + \exp(s_{13})\right) - \exp(s_{12})\exp(s_{11})}{\left(\exp(s_{11}) + \exp(s_{12}) + \exp(s_{13})\right)^2} = -\frac{\exp(s_{12})\exp(s_{11})}{\left(\exp(s_{11}) + \exp(s_{12}) + \exp(s_{13})\right)^2} = -p_{11}p_{12} \]

Since $p_{13} = \frac{\exp(s_{13})}{\exp(s_{11}) + \exp(s_{12}) + \exp(s_{13})}$,

\[ \frac{\partial p_{13}}{\partial s_{11}} = \frac{0 \cdot \left(\exp(s_{11}) + \exp(s_{12}) + \exp(s_{13})\right) - \exp(s_{13})\exp(s_{11})}{\left(\exp(s_{11}) + \exp(s_{12}) + \exp(s_{13})\right)^2} = -\frac{\exp(s_{13})\exp(s_{11})}{\left(\exp(s_{11}) + \exp(s_{12}) + \exp(s_{13})\right)^2} = -p_{11}p_{13} \]

So

\[ \frac{\partial L}{\partial s_{11}} = \frac{\partial L}{\partial p_{11}}\left(p_{11} - p_{11}^2\right) + \frac{\partial L}{\partial p_{12}}\left(-p_{11}p_{12}\right) + \frac{\partial L}{\partial p_{13}}\left(-p_{11}p_{13}\right) \]

Similarly, we can derive that

\[ \frac{\partial L}{\partial s_{12}} = \frac{\partial L}{\partial p_{11}}\left(-p_{11}p_{12}\right) + \frac{\partial L}{\partial p_{12}}\left(p_{12} - p_{12}^2\right) + \frac{\partial L}{\partial p_{13}}\left(-p_{12}p_{13}\right) \]

\[ \frac{\partial L}{\partial s_{13}} = \frac{\partial L}{\partial p_{11}}\left(-p_{11}p_{13}\right) + \frac{\partial L}{\partial p_{12}}\left(-p_{12}p_{13}\right) + \frac{\partial L}{\partial p_{13}}\left(p_{13} - p_{13}^2\right) \]

Let's write the row vectors $\frac{\partial L}{\partial S_{1}} = \left[\frac{\partial L}{\partial s_{11}}, \frac{\partial L}{\partial s_{12}}, \frac{\partial L}{\partial s_{13}}\right]$ and $\frac{\partial L}{\partial P_{1}} = \left[\frac{\partial L}{\partial p_{11}}, \frac{\partial L}{\partial p_{12}}, \frac{\partial L}{\partial p_{13}}\right]$. Then we have

\[ \frac{\partial L}{\partial S_{1}} = \frac{\partial L}{\partial P_{1}} \begin{bmatrix} p_{11}(1-p_{11}) & -p_{11}p_{12} & -p_{11}p_{13} \\ -p_{11}p_{12} & p_{12}(1-p_{12}) & -p_{12}p_{13} \\ -p_{11}p_{13} & -p_{12}p_{13} & p_{13}(1-p_{13}) \end{bmatrix} \]

Let $P_1 = \left[p_{11}, p_{12}, p_{13}\right]$ be the first row of $P$. Then we have

\[ \begin{bmatrix} p_{11}(1-p_{11}) & -p_{11}p_{12} & -p_{11}p_{13} \\ -p_{11}p_{12} & p_{12}(1-p_{12}) & -p_{12}p_{13} \\ -p_{11}p_{13} & -p_{12}p_{13} & p_{13}(1-p_{13}) \end{bmatrix} = \begin{bmatrix} p_{11} & 0 & 0 \\ 0 & p_{12} & 0 \\ 0 & 0 & p_{13} \end{bmatrix} - P_1^T P_1 \]

So

\[ \frac{\partial L}{\partial S_{1}} = \frac{\partial L}{\partial P_{1}} \circ P_1 - \frac{\partial L}{\partial P_{1}} P_1^T P_1 \]

where $\circ$ is the element-wise product.
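The row formula can be checked on a single row with a short NumPy sketch. The variable names are illustrative, the random values stand in for $S_1$ and $\frac{\partial L}{\partial P_1}$, and the loss used for the finite-difference check, $L = \sum_j g_j\, p_{1j}$, is a hypothetical choice made so that $\frac{\partial L}{\partial P_1}$ equals a fixed vector.

```python
import numpy as np

def softmax(s):
    shifted = np.exp(s - s.max())  # max-subtraction is a stability trick
    return shifted / shifted.sum()

rng = np.random.default_rng(0)
s1 = rng.standard_normal(3)   # first row of S
dP1 = rng.standard_normal(3)  # stands in for dL/dP_1
p1 = softmax(s1)

# Explicit Jacobian of the row softmax: diag(P_1) - P_1^T P_1.
J = np.diag(p1) - np.outer(p1, p1)
dS1_jacobian = dP1 @ J

# Same quantity without building the Jacobian:
# dL/dP_1 * P_1 (element-wise) minus the scalar (dL/dP_1 P_1^T) times P_1.
dS1_compact = dP1 * p1 - (dP1 @ p1) * p1

# Finite-difference check on the hypothetical loss L = sum(dP1 * softmax(s1)).
eps = 1e-6
dS1_numeric = np.array([
    (dP1 @ softmax(s1 + eps * unit) - dP1 @ softmax(s1 - eps * unit)) / (2 * eps)
    for unit in np.eye(3)
])

print(np.allclose(dS1_jacobian, dS1_compact))
print(np.allclose(dS1_compact, dS1_numeric, atol=1e-6))
```

The compact form matters in practice because it avoids materializing a `seq_len x seq_len` Jacobian for every row.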
From the last section, $\frac{\partial L}{\partial P_{1}} = \frac{\partial L}{\partial O_{1}}V^T$, where $\frac{\partial L}{\partial O_{1}} = \left[\frac{\partial L}{\partial o_{11}}, \frac{\partial L}{\partial o_{12}}, \frac{\partial L}{\partial o_{13}}\right]$. Writing $O_1 = P_1 V$ for the first row of $O$, we have

\[ \begin{aligned} \frac{\partial L}{\partial S_{1}} &= \frac{\partial L}{\partial P_{1}} \circ P_1 - \frac{\partial L}{\partial O_{1}}V^T P_1^T P_1 \\ &= \frac{\partial L}{\partial P_{1}} \circ P_1 - \frac{\partial L}{\partial O_{1}} (P_1 V)^T P_1 \\ &= \frac{\partial L}{\partial P_{1}} \circ P_1 - \frac{\partial L}{\partial O_{1}} O_1^T P_1 \\ &= \frac{\partial L}{\partial P_{1}} \circ P_1 - \text{ROWSUM}\left(\frac{\partial L}{\partial O_{1}} \circ O_1\right) P_1 \end{aligned} \]

The last step uses the fact that $\frac{\partial L}{\partial O_{1}} O_1^T$ is a scalar: the sum of the element-wise products of the two row vectors.

Extending this to all rows, we have

\[ \begin{aligned} \frac{\partial L}{\partial S} &= \frac{\partial L}{\partial P} \circ P - \text{ROWSUM}\left(\frac{\partial L}{\partial O} \circ O\right) \circ P \end{aligned} \]

where $\text{ROWSUM}$ produces one scalar per row, which is broadcast across the columns of $P$.

### Gradient of $Q$ and $K$

Since $S = QK^T$, we have

\[ \frac{\partial L}{\partial Q} = \frac{\partial L}{\partial S}K \]

\[ \frac{\partial L}{\partial K} = \left(\frac{\partial L}{\partial S}\right)^T Q \]

The derivation is similar to that of the gradients of $V$ and $P$.

## References

[1]: @misc{dao2022flashattentionfastmemoryefficientexact,
  title={FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness},
  author={Tri Dao and Daniel Y. Fu and Stefano Ermon and Atri Rudra and Christopher Ré},
  year={2022},
  eprint={2205.14135},
  archivePrefix={arXiv},
  primaryClass={cs.LG},
  url={https://arxiv.org/abs/2205.14135},
}

[2]: @misc{dao2023flashattention2fasterattentionbetter,
  title={FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning},
  author={Tri Dao},
  year={2023},
  eprint={2307.08691},
  archivePrefix={arXiv},
  primaryClass={cs.LG},
  url={https://arxiv.org/abs/2307.08691},
}

[3]: @misc{vaswani2023attentionneed,
  title={Attention Is All You Need},
  author={Ashish Vaswani and Noam Shazeer and Niki Parmar and Jakob Uszkoreit and Llion Jones and Aidan N. Gomez and Lukasz Kaiser and Illia Polosukhin},
  year={2023},
  eprint={1706.03762},
  archivePrefix={arXiv},
  primaryClass={cs.CL},
  url={https://arxiv.org/abs/1706.03762},
}
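As a closing check, the formulas derived in this post can be put together into one backward function and compared against finite differences. This is a sketch under the same simplifications used throughout (single head, no batch dimension, no scaling factor). The names `attention_backward` and `numerical_grad` are made up for illustration, and the loss is a hypothetical one chosen so that $\frac{\partial L}{\partial O}$ equals a fixed random matrix.

```python
import numpy as np

def softmax_rows(S):
    E = np.exp(S - S.max(axis=-1, keepdims=True))  # stability trick
    return E / E.sum(axis=-1, keepdims=True)

def attention_backward(Q, K, V, dO):
    """Gradients of unscaled attention, following the formulas in this post."""
    S = Q @ K.T
    P = softmax_rows(S)
    O = P @ V
    dV = P.T @ dO                                # dL/dV = P^T dL/dO
    dP = dO @ V.T                                # dL/dP = dL/dO V^T
    D = np.sum(dO * O, axis=-1, keepdims=True)   # ROWSUM(dL/dO * O), shape (seq_len, 1)
    dS = dP * P - D * P                          # D broadcasts across columns
    dQ = dS @ K                                  # dL/dQ = dL/dS K
    dK = dS.T @ Q                                # dL/dK = (dL/dS)^T Q
    return dQ, dK, dV

def numerical_grad(f, X, eps=1e-6):
    """Central finite differences, perturbing one entry of X at a time."""
    G = np.zeros_like(X)
    for idx in np.ndindex(*X.shape):
        X_plus, X_minus = X.copy(), X.copy()
        X_plus[idx] += eps
        X_minus[idx] -= eps
        G[idx] = (f(X_plus) - f(X_minus)) / (2 * eps)
    return G

rng = np.random.default_rng(0)
Q, K, V, dO = (rng.standard_normal((3, 3)) for _ in range(4))

def loss(Q, K, V):
    # Hypothetical loss chosen so that dL/dO equals dO exactly.
    return np.sum((softmax_rows(Q @ K.T) @ V) * dO)

dQ, dK, dV = attention_backward(Q, K, V, dO)
dQ_numeric = numerical_grad(lambda X: loss(X, K, V), Q)
dK_numeric = numerical_grad(lambda X: loss(Q, X, V), K)
dV_numeric = numerical_grad(lambda X: loss(Q, K, X), V)
print(np.allclose(dQ, dQ_numeric, atol=1e-5))
print(np.allclose(dK, dK_numeric, atol=1e-5))
print(np.allclose(dV, dV_numeric, atol=1e-5))
```

Note that `D` only needs $\frac{\partial L}{\partial O}$ and $O$, so it can be computed without the full $P$ matrix, which is the property the row-sum rewrite was aiming for.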