{"slug": "attention-backpropagation-step-by-step-derivation", "title": "Attention Backpropagation: Step by step derivation", "summary": "A blog post derives the backward pass of attention mechanisms step by step, using a concrete example to illustrate gradient computation for Q, K, and V matrices. The derivation builds on FlashAttention and FlashAttention2 papers, focusing on the softmax and matrix multiplication operations in the forward pass.", "body_md": "I recently revisited the FlashAttention[1] and FlashAttention2[2] papers. It is really fun to derive the backward pass of attention by hand. In this blog post, I will use a concrete example to illustrate the process, and I hope it is easy to follow.\n\n# Forward Pass\n\nAttention[3] involves three matrices: $Q$, $K$, and $V$, each of shape [batch_size, num_heads, seq_len, head_dim]. Attention is calculated as follows:\n\n\\[\\text{Attention}(Q, K, V) = \\text{softmax}(\\frac{QK^T}{\\sqrt{head\\_dim}})V\\]Let me use a simple example with $seq\\_len = 3$ and $head\\_dim = 3$ to illustrate this process. We will ignore the $batch\\_size$ and $num\\_heads$ dimensions in this example because the matrix multiplications act only on the $seq\\_len$ and $head\\_dim$ dimensions. 
And we will also ignore the scaling factor $\\frac{1}{\\sqrt{head\\_dim}}$ for simplicity.\n\n\\[Q = \\begin{bmatrix} q_{11} & q_{12} & q_{13} \\\\ q_{21} & q_{22} & q_{23} \\\\ q_{31} & q_{32} & q_{33} \\end{bmatrix}\\] \\[K = \\begin{bmatrix} k_{11} & k_{12} & k_{13} \\\\ k_{21} & k_{22} & k_{23} \\\\ k_{31} & k_{32} & k_{33} \\end{bmatrix}\\] \\[V = \\begin{bmatrix} v_{11} & v_{12} & v_{13} \\\\ v_{21} & v_{22} & v_{23} \\\\ v_{31} & v_{32} & v_{33} \\end{bmatrix}\\]So\n\n\\[QK^T = S = \\begin{bmatrix} q_{11}k_{11} + q_{12}k_{12} + q_{13}k_{13} & q_{11}k_{21} + q_{12}k_{22} + q_{13}k_{23} & q_{11}k_{31} + q_{12}k_{32} + q_{13}k_{33} \\\\ q_{21}k_{11} + q_{22}k_{12} + q_{23}k_{13} & q_{21}k_{21} + q_{22}k_{22} + q_{23}k_{23} & q_{21}k_{31} + q_{22}k_{32} + q_{23}k_{33} \\\\ q_{31}k_{11} + q_{32}k_{12} + q_{33}k_{13} & q_{31}k_{21} + q_{32}k_{22} + q_{33}k_{23} & q_{31}k_{31} + q_{32}k_{32} + q_{33}k_{33} \\end{bmatrix} = \\begin{bmatrix} s_{11} & s_{12} & s_{13} \\\\ s_{21} & s_{22} & s_{23} \\\\ s_{31} & s_{32} & s_{33} \\end{bmatrix}\\] \\[P = \\text{softmax}(S) = \\begin{bmatrix} \\frac{exp(s_{11})}{exp(s_{11}) + exp(s_{12}) + exp(s_{13})} & \\frac{exp(s_{12})}{exp(s_{11}) + exp(s_{12}) + exp(s_{13})} & \\frac{exp(s_{13})}{exp(s_{11}) + exp(s_{12}) + exp(s_{13})} \\\\ \\frac{exp(s_{21})}{exp(s_{21}) + exp(s_{22}) + exp(s_{23})} & \\frac{exp(s_{22})}{exp(s_{21}) + exp(s_{22}) + exp(s_{23})} & \\frac{exp(s_{23})}{exp(s_{21}) + exp(s_{22}) + exp(s_{23})} \\\\ \\frac{exp(s_{31})}{exp(s_{31}) + exp(s_{32}) + exp(s_{33})} & \\frac{exp(s_{32})}{exp(s_{31}) + exp(s_{32}) + exp(s_{33})} & \\frac{exp(s_{33})}{exp(s_{31}) + exp(s_{32}) + exp(s_{33})} \\end{bmatrix} = \\begin{bmatrix} p_{11} & p_{12} & p_{13} \\\\ p_{21} & p_{22} & p_{23} \\\\ p_{31} & p_{32} & p_{33} \\end{bmatrix}\\] \\[O = PV = \\begin{bmatrix} p_{11}v_{11} + p_{12}v_{21} + p_{13}v_{31} & p_{11}v_{12} + p_{12}v_{22} + p_{13}v_{32} & p_{11}v_{13} + p_{12}v_{23} + p_{13}v_{33} \\\\ p_{21}v_{11} + 
p_{22}v_{21} + p_{23}v_{31} & p_{21}v_{12} + p_{22}v_{22} + p_{23}v_{32} & p_{21}v_{13} + p_{22}v_{23} + p_{23}v_{33} \\\\ p_{31}v_{11} + p_{32}v_{21} + p_{33}v_{31} & p_{31}v_{12} + p_{32}v_{22} + p_{33}v_{32} & p_{31}v_{13} + p_{32}v_{23} + p_{33}v_{33} \\end{bmatrix} = \\begin{bmatrix} o_{11} & o_{12} & o_{13} \\\\ o_{21} & o_{22} & o_{23} \\\\ o_{31} & o_{32} & o_{33} \\end{bmatrix}\\]$O$ is the output of attention.\n\n# Backward Pass\n\nIn the backward pass, the input is the partial derivative of the loss with respect to $O$.\n\n\\[\\frac{\\partial L}{\\partial O} = \\begin{bmatrix} \\frac{\\partial L}{\\partial o_{11}} & \\frac{\\partial L}{\\partial o_{12}} & \\frac{\\partial L}{\\partial o_{13}} \\\\ \\frac{\\partial L}{\\partial o_{21}} & \\frac{\\partial L}{\\partial o_{22}} & \\frac{\\partial L}{\\partial o_{23}} \\\\ \\frac{\\partial L}{\\partial o_{31}} & \\frac{\\partial L}{\\partial o_{32}} & \\frac{\\partial L}{\\partial o_{33}} \\end{bmatrix}\\]When we use a deep learning framework such as PyTorch or JAX, this derivative is computed automatically. We will use it to compute the gradients $\\frac{\\partial L}{\\partial Q}$, $\\frac{\\partial L}{\\partial K}$, and $\\frac{\\partial L}{\\partial V}$.\n\n## Gradient of $V$ and $P$\n\nThis is the most straightforward part. 
Remember that $O = PV$,\n\n\\[O = PV = \\begin{bmatrix} p_{11}v_{11} + p_{12}v_{21} + p_{13}v_{31} & p_{11}v_{12} + p_{12}v_{22} + p_{13}v_{32} & p_{11}v_{13} + p_{12}v_{23} + p_{13}v_{33} \\\\ p_{21}v_{11} + p_{22}v_{21} + p_{23}v_{31} & p_{21}v_{12} + p_{22}v_{22} + p_{23}v_{32} & p_{21}v_{13} + p_{22}v_{23} + p_{23}v_{33} \\\\ p_{31}v_{11} + p_{32}v_{21} + p_{33}v_{31} & p_{31}v_{12} + p_{32}v_{22} + p_{33}v_{32} & p_{31}v_{13} + p_{32}v_{23} + p_{33}v_{33} \\end{bmatrix} = \\begin{bmatrix} o_{11} & o_{12} & o_{13} \\\\ o_{21} & o_{22} & o_{23} \\\\ o_{31} & o_{32} & o_{33} \\end{bmatrix}\\]So for example $\\frac{\\partial L}{\\partial v_{11}}$, it appears in the first column of $O$, so\n\n\\[\\frac{\\partial L}{\\partial v_{11}} = \\frac{\\partial L}{\\partial o_{11}}\\frac{\\partial o_{11}}{\\partial v_{11}} + \\frac{\\partial L}{\\partial o_{21}}\\frac{\\partial o_{21}}{\\partial v_{11}} + \\frac{\\partial L}{\\partial o_{31}}\\frac{\\partial o_{31}}{\\partial v_{11}}\\]Since $o_{11} = p_{11}v_{11} + p_{12}v_{21} + p_{13}v_{31}$,\n\n\\[\\frac{\\partial o_{11}}{\\partial v_{11}} = p_{11}\\] \\[\\frac{\\partial o_{21}}{\\partial v_{11}} = p_{21}\\] \\[\\frac{\\partial o_{31}}{\\partial v_{11}} = p_{31}\\]So\n\n\\[\\frac{\\partial L}{\\partial v_{11}} = \\frac{\\partial L}{\\partial o_{11}}p_{11} + \\frac{\\partial L}{\\partial o_{21}}p_{21} + \\frac{\\partial L}{\\partial o_{31}}p_{31}\\]So\n\n\\[\\frac{\\partial L}{\\partial V} = \\begin{bmatrix} \\frac{\\partial L}{\\partial v_{11}} & \\frac{\\partial L}{\\partial v_{12}} & \\frac{\\partial L}{\\partial v_{13}} \\\\ \\frac{\\partial L}{\\partial v_{21}} & \\frac{\\partial L}{\\partial v_{22}} & \\frac{\\partial L}{\\partial v_{23}} \\\\ \\frac{\\partial L}{\\partial v_{31}} & \\frac{\\partial L}{\\partial v_{32}} & \\frac{\\partial L}{\\partial v_{33}} \\end{bmatrix} = \\begin{bmatrix} p_{11}\\frac{\\partial L}{\\partial o_{11}} + p_{21}\\frac{\\partial L}{\\partial o_{21}} + p_{31}\\frac{\\partial 
L}{\\partial o_{31}} & p_{11}\\frac{\\partial L}{\\partial o_{12}} + p_{21}\\frac{\\partial L}{\\partial o_{22}} + p_{31}\\frac{\\partial L}{\\partial o_{32}} & p_{11}\\frac{\\partial L}{\\partial o_{13}} + p_{21}\\frac{\\partial L}{\\partial o_{23}} + p_{31}\\frac{\\partial L}{\\partial o_{33}} \\\\ p_{12}\\frac{\\partial L}{\\partial o_{11}} + p_{22}\\frac{\\partial L}{\\partial o_{21}} + p_{32}\\frac{\\partial L}{\\partial o_{31}} & p_{12}\\frac{\\partial L}{\\partial o_{12}} + p_{22}\\frac{\\partial L}{\\partial o_{22}} + p_{32}\\frac{\\partial L}{\\partial o_{32}} & p_{12}\\frac{\\partial L}{\\partial o_{13}} + p_{22}\\frac{\\partial L}{\\partial o_{23}} + p_{32}\\frac{\\partial L}{\\partial o_{33}} \\\\ p_{13}\\frac{\\partial L}{\\partial o_{11}} + p_{23}\\frac{\\partial L}{\\partial o_{21}} + p_{33}\\frac{\\partial L}{\\partial o_{31}} & p_{13}\\frac{\\partial L}{\\partial o_{12}} + p_{23}\\frac{\\partial L}{\\partial o_{22}} + p_{33}\\frac{\\partial L}{\\partial o_{32}} & p_{13}\\frac{\\partial L}{\\partial o_{13}} + p_{23}\\frac{\\partial L}{\\partial o_{23}} + p_{33}\\frac{\\partial L}{\\partial o_{33}} \\end{bmatrix}\\]So\n\n\\[\\frac{\\partial L}{\\partial V} = P^T \\frac{\\partial L}{\\partial O}\\]Similarly,\n\n\\[\\frac{\\partial L}{\\partial P} = \\begin{bmatrix} \\frac{\\partial L}{\\partial p_{11}} & \\frac{\\partial L}{\\partial p_{12}} & \\frac{\\partial L}{\\partial p_{13}} \\\\ \\frac{\\partial L}{\\partial p_{21}} & \\frac{\\partial L}{\\partial p_{22}} & \\frac{\\partial L}{\\partial p_{23}} \\\\ \\frac{\\partial L}{\\partial p_{31}} & \\frac{\\partial L}{\\partial p_{32}} & \\frac{\\partial L}{\\partial p_{33}} \\end{bmatrix} = \\frac{\\partial L}{\\partial O}V^T\\]\n\n## Gradient of $S$\n\nTo compute the gradient of $K$ and $Q$, we need to compute the gradient of $S$ first.\n\nRemember that\n\n\\[P = \\text{softmax}(S) = \\begin{bmatrix} \\frac{exp(s_{11})}{exp(s_{11}) + exp(s_{12}) + exp(s_{13})} & \\frac{exp(s_{12})}{exp(s_{11}) + 
exp(s_{12}) + exp(s_{13})} & \\frac{exp(s_{13})}{exp(s_{11}) + exp(s_{12}) + exp(s_{13})} \\\\ \\frac{exp(s_{21})}{exp(s_{21}) + exp(s_{22}) + exp(s_{23})} & \\frac{exp(s_{22})}{exp(s_{21}) + exp(s_{22}) + exp(s_{23})} & \\frac{exp(s_{23})}{exp(s_{21}) + exp(s_{22}) + exp(s_{23})} \\\\ \\frac{exp(s_{31})}{exp(s_{31}) + exp(s_{32}) + exp(s_{33})} & \\frac{exp(s_{32})}{exp(s_{31}) + exp(s_{32}) + exp(s_{33})} & \\frac{exp(s_{33})}{exp(s_{31}) + exp(s_{32}) + exp(s_{33})} \\end{bmatrix} = \\begin{bmatrix} p_{11} & p_{12} & p_{13} \\\\ p_{21} & p_{22} & p_{23} \\\\ p_{31} & p_{32} & p_{33} \\end{bmatrix}\\]So for example $s_{11}$ appears in the first row of $P$, so\n\n\\[\\frac{\\partial L}{\\partial s_{11}} = \\frac{\\partial L}{\\partial p_{11}}\\frac{\\partial p_{11}}{\\partial s_{11}} + \\frac{\\partial L}{\\partial p_{12}}\\frac{\\partial p_{12}}{\\partial s_{11}} + \\frac{\\partial L}{\\partial p_{13}}\\frac{\\partial p_{13}}{\\partial s_{11}}\\]Since $p_{11} = \\frac{exp(s_{11})}{exp(s_{11}) + exp(s_{12}) + exp(s_{13})}$,\n\n\\[\\frac{\\partial p_{11}}{\\partial s_{11}} = \\frac{exp(s_{11})(exp(s_{11}) + exp(s_{12}) + exp(s_{13})) - exp(s_{11})exp(s_{11})}{(exp(s_{11}) + exp(s_{12}) + exp(s_{13}))^2} = \\frac{exp(s_{11})}{exp(s_{11}) + exp(s_{12}) + exp(s_{13})} - \\frac{exp(s_{11})^2}{(exp(s_{11}) + exp(s_{12}) + exp(s_{13}))^2} = p_{11} - p_{11}^2\\]Since $p_{12} = \\frac{exp(s_{12})}{exp(s_{11}) + exp(s_{12}) + exp(s_{13})}$,\n\n\\[\\frac{\\partial p_{12}}{\\partial s_{11}} = \\frac{0 * (exp(s_{11}) + exp(s_{12}) + exp(s_{13})) - exp(s_{12})exp(s_{11})}{(exp(s_{11}) + exp(s_{12}) + exp(s_{13}))^2} = - \\frac{exp(s_{12})exp(s_{11})}{(exp(s_{11}) + exp(s_{12}) + exp(s_{13}))^2} = -p_{11}p_{12}\\]Since $p_{13} = \\frac{exp(s_{13})}{exp(s_{11}) + exp(s_{12}) + exp(s_{13})}$,\n\n\\[\\frac{\\partial p_{13}}{\\partial s_{11}} = \\frac{0 * (exp(s_{11}) + exp(s_{12}) + exp(s_{13})) - exp(s_{13})exp(s_{11})}{(exp(s_{11}) + exp(s_{12}) + exp(s_{13}))^2} = - 
\\frac{exp(s_{13})exp(s_{11})}{(exp(s_{11}) + exp(s_{12}) + exp(s_{13}))^2} = -p_{11}p_{13}\\]So\n\n\\[\\frac{\\partial L}{\\partial s_{11}} = \\frac{\\partial L}{\\partial p_{11}}(p_{11} - p_{11}^2) + \\frac{\\partial L}{\\partial p_{12}}(-p_{11}p_{12}) + \\frac{\\partial L}{\\partial p_{13}}(-p_{11}p_{13})\\]Similarly, we can derive that\n\n\\[\\frac{\\partial L}{\\partial s_{12}} = \\frac{\\partial L}{\\partial p_{11}}( - p_{11}p_{12}) + \\frac{\\partial L}{\\partial p_{12}}(p_{12} -p_{12}^2) + \\frac{\\partial L}{\\partial p_{13}}(-p_{12}p_{13})\\] \\[\\frac{\\partial L}{\\partial s_{13}} = \\frac{\\partial L}{\\partial p_{11}}( - p_{11}p_{13}) + \\frac{\\partial L}{\\partial p_{12}}(-p_{12}p_{13}) + \\frac{\\partial L}{\\partial p_{13}}(p_{13} -p_{13}^2)\\]Let’s define the row vectors $\\frac{\\partial L}{\\partial S_{1}} = (\\frac{\\partial L}{\\partial s_{11}}, \\frac{\\partial L}{\\partial s_{12}}, \\frac{\\partial L}{\\partial s_{13}})$ and $\\frac{\\partial L}{\\partial P_{1}} = (\\frac{\\partial L}{\\partial p_{11}}, \\frac{\\partial L}{\\partial p_{12}}, \\frac{\\partial L}{\\partial p_{13}})$, then we have\n\n\\[\\frac{\\partial L}{\\partial S_{1}} = \\frac{\\partial L}{\\partial P_{1}} \\begin{bmatrix} p_{11} * (1-p_{11}) & -p_{11}p_{12} & -p_{11}p_{13} \\\\ -p_{11}p_{12} & p_{12} * (1-p_{12}) & -p_{12}p_{13} \\\\ -p_{11}p_{13} & -p_{12}p_{13} & p_{13} * (1-p_{13}) \\end{bmatrix}\\]Let $P_1 = (p_{11}, p_{12}, p_{13})$, then we have\n\n\\[\\begin{bmatrix} p_{11} * (1-p_{11}) & -p_{11}p_{12} & -p_{11}p_{13} \\\\ -p_{11}p_{12} & p_{12} * (1-p_{12}) & -p_{12}p_{13} \\\\ -p_{11}p_{13} & -p_{12}p_{13} & p_{13} * (1-p_{13}) \\end{bmatrix} = \\begin{bmatrix} p_{11} & 0 & 0 \\\\ 0 & p_{12} & 0 \\\\ 0 & 0 & p_{13} \\end{bmatrix} - P_1^T P_1\\]So\n\n\\[\\frac{\\partial L}{\\partial S_{1}} = \\frac{\\partial L}{\\partial P_{1}} \\circ P_1 - (\\frac{\\partial L}{\\partial P_{1}} P_1^T) P_1\\]where $\\circ$ is the element-wise product. 
From the last section, we have $\\frac{\\partial L}{\\partial P_{1}} = \\frac{\\partial L}{\\partial O_{1}}V^T$, where $\\frac{\\partial L}{\\partial O_{1}} = (\\frac{\\partial L}{\\partial o_{11}}, \\frac{\\partial L}{\\partial o_{12}}, \\frac{\\partial L}{\\partial o_{13}})$. Since $O_1 = P_1 V$ is the first row of $O$, we have\n\n\\[\\begin{align*} \\frac{\\partial L}{\\partial S_{1}} &= \\frac{\\partial L}{\\partial P_{1}} \\circ P_1 - (\\frac{\\partial L}{\\partial O_{1}}V^T P_1^T) P_1 \\\\ &= \\frac{\\partial L}{\\partial P_{1}} \\circ P_1 - (\\frac{\\partial L}{\\partial O_{1}}(P_1 V)^T) P_1 \\\\ &= \\frac{\\partial L}{\\partial P_{1}} \\circ P_1 - (\\frac{\\partial L}{\\partial O_{1}}O_1^T) P_1 \\\\ &= \\frac{\\partial L}{\\partial P_{1}} \\circ P_1 - ROW\\_SUM(\\frac{\\partial L}{\\partial O_{1}} \\circ O_1) P_1 \\end{align*}\\]So extending this to all rows, we have\n\n\\[\\begin{align*} \\frac{\\partial L}{\\partial S} &= \\frac{\\partial L}{\\partial P} \\circ P - ROW\\_SUM(\\frac{\\partial L}{\\partial O} \\circ O) \\circ P \\end{align*}\\]Here $ROW\\_SUM$ sums each row to produce a column vector, which is broadcast across the columns of $P$ in the element-wise product.\n\n## Gradient of $Q$ and $K$\n\nSince $S = QK^T$, we have\n\n\\[\\frac{\\partial L}{\\partial Q} = \\frac{\\partial L}{\\partial S}K\\] \\[\\frac{\\partial L}{\\partial K} = (\\frac{\\partial L}{\\partial S})^T Q\\]The derivation is similar to the gradient of $V$ and $P$.\n\n# References\n\n[1]: @misc{dao2022flashattentionfastmemoryefficientexact, title={FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness}, author={Tri Dao and Daniel Y. 
Fu and Stefano Ermon and Atri Rudra and Christopher Ré}, year={2022}, eprint={2205.14135}, archivePrefix={arXiv}, primaryClass={cs.LG}, url={https://arxiv.org/abs/2205.14135}, }\n\n[2]: @misc{dao2023flashattention2fasterattentionbetter, title={FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning}, author={Tri Dao}, year={2023}, eprint={2307.08691}, archivePrefix={arXiv}, primaryClass={cs.LG}, url={https://arxiv.org/abs/2307.08691}, }\n\n[3]: @misc{vaswani2023attentionneed, title={Attention Is All You Need}, author={Ashish Vaswani and Noam Shazeer and Niki Parmar and Jakob Uszkoreit and Llion Jones and Aidan N. Gomez and Lukasz Kaiser and Illia Polosukhin}, year={2023}, eprint={1706.03762}, archivePrefix={arXiv}, primaryClass={cs.CL}, url={https://arxiv.org/abs/1706.03762}, }", "url": "https://wpnews.pro/news/attention-backpropagation-step-by-step-derivation", "canonical_source": "https://liyuan24.github.io/writings/attention_backprop.html", "published_at": "2026-06-16 10:16:08+00:00", "updated_at": "2026-06-16 10:49:16.416652+00:00", "lang": "en", "topics": ["machine-learning", "large-language-models", "neural-networks", "ai-research"], "entities": ["FlashAttention", "FlashAttention2", "Pytorch", "Jax"], "alternates": {"html": "https://wpnews.pro/news/attention-backpropagation-step-by-step-derivation", "markdown": "https://wpnews.pro/news/attention-backpropagation-step-by-step-derivation.md", "text": "https://wpnews.pro/news/attention-backpropagation-step-by-step-derivation.txt", "jsonld": "https://wpnews.pro/news/attention-backpropagation-step-by-step-derivation.jsonld"}}