{"slug": "the-annotated-pytorch-training-loop", "title": "The annotated PyTorch training loop", "summary": "A detailed breakdown of the PyTorch training loop reveals common placement errors that silently break training, such as moving the model to a device after optimizer creation or calling gradient clipping before the backward pass. The annotated loop explains each line's purpose and the consequences of misordering operations, helping developers avoid convergence failures and memory issues.", "body_md": "What each line does and what breaks if you move it.\n\nBuilding a PyTorch training loop is fairly straightforward, but getting everything in the right place and in the right order can feel surprisingly fragile. There are a lot of moving parts, and once the most basic errors are fixed, most of the remaining mistakes are hard to spot. Misplace a line and a training run can fail to converge, produce incorrect results, or consume excessive memory.\n\nThe sections below go through each operation in sequence, explaining how to write it and which mistakes to watch out for.\n\nDistributed training, FSDP, and multi-GPU setups are out of scope here; we'll come back to them in a future essay.\n\n*(The animation above was produced by running the loop on synthetic data and capturing the decision boundary at each epoch.)*\n\nLet's look, first of all, at the complete training loop. 
You don't need to understand or memorise it yet; just get a feel for the structure.\n\n``` python\nimport torch\nimport torch.nn as nn\nfrom torch.utils.data import DataLoader, TensorDataset\n\n# X_train, y_train, X_val, y_val and the MLP class are defined in the sections below\n\n# --- data ---\ndataset = TensorDataset(X_train, y_train)\nloader  = DataLoader(dataset, batch_size=64, shuffle=True)\n\n# --- model, loss, optimiser ---\nmodel     = MLP(in_features=2, hidden=128, out_features=3)\ncriterion = nn.CrossEntropyLoss()\noptimiser = torch.optim.Adam(model.parameters(), lr=1e-3)\nscheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimiser, T_max=100)\n\n# --- loop ---\nfor epoch in range(100):\n    model.train()\n    for X_batch, y_batch in loader:\n        optimiser.zero_grad()\n        logits = model(X_batch)\n        loss   = criterion(logits, y_batch)\n        loss.backward()\n        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)\n        optimiser.step()\n    scheduler.step()\n\n    model.eval()\n    with torch.no_grad():\n        val_logits = model(X_val)\n        val_loss   = criterion(val_logits, y_val)\n```\n\nBefore going through each line, here are the most common ways to break the loop by getting the placement slightly wrong. None of them raises an exception, which is why they are worth memorising. Over time you'll develop a sense for what these errors look like in a training run, but for the first few runs this crib sheet will help.\n\n| Line | Wrong position | What breaks |\n|---|---|---|\n| `model.to(device)` | After `optimiser = ...` | When a dtype conversion is combined (e.g. `model.half()`), `nn.Module.to()` can allocate new `nn.Parameter` objects; the optimiser holds references to the discarded originals and applies updates to them instead. 
|\n| `optimiser.zero_grad()` | Omitted | Gradients from successive batches accumulate. The update uses their running sum, not the current batch alone. |\n| `optimiser.zero_grad()` | Between `loss.backward()` and `optimiser.step()` | The fresh gradients are wiped before they are used. With the default `set_to_none=True`, `step()` skips every parameter and the weights never change. |\n| `clip_grad_norm_()` | Before `loss.backward()` | `.grad` is empty (or stale). The call does nothing useful. |\n| `clip_grad_norm_()` | After `optimiser.step()` | Clips gradients that have already been applied. No effect on the update. |\n| `scheduler.step()` | Inside the batch loop | The LR decays `len(loader)` times per epoch instead of once. |\n| `model.train()` | Omitted after `model.eval()` | Dropout stays disabled and BatchNorm stays frozen. The model trains in eval mode without error. |\n| `torch.no_grad()` | Omitted during validation | An autograd graph is built for every validation batch. Memory use rises, and if outputs are accumulated it grows until OOM. |\n| `loss.item()` | Logging `loss` instead | Keeps the computation graph alive for as long as the logged tensor is referenced; accumulating it across steps grows memory every step. |\n\nNow let's go through each step in detail.\n\nThere are two parts to the data pipeline in PyTorch: the `Dataset` and the `DataLoader`. A `Dataset` is just a Python object that implements `__len__` (how many elements are in the dataset) and `__getitem__` (which, unsurprisingly, gets an item). It can be a simple wrapper around tensors, or it can load data from disk on demand.\n\nThe `DataLoader` wraps a dataset and produces batches. Each pass through the full dataset is one epoch. With `shuffle=True`, examples are presented in a different order each epoch.\n\n```\ndataset = TensorDataset(X_train, y_train)\nloader  = DataLoader(\n    dataset,\n    batch_size=64,\n    shuffle=True,\n    num_workers=2,\n    pin_memory=True,\n    persistent_workers=True,\n)\n```\n\n`TensorDataset` pairs input and label tensors by index: `dataset[i]` returns `(X[i], y[i])`. The `DataLoader` calls `__getitem__` repeatedly, collates the results into batches, and optionally hands the work to background worker processes.\n\n**num_workers** spawns separate processes that prefetch batches in parallel with GPU compute. 
The main process blocks on `next()` only if a batch is not yet ready. With zero workers the main process does all the loading, which often bottlenecks GPU utilisation on data-heavy tasks. Two to four workers is a practical starting point, but the right number depends on CPU count and I/O speed.\n\n**pin_memory=True** allocates batch tensors in pinned (page-locked) host memory. The GPU's DMA engine can copy directly from pinned memory, whereas pageable memory must first be staged through a pinned buffer by the driver, so host-to-device transfers are faster and can run asynchronously. It only helps when you're transferring to a CUDA device.\n\n**persistent_workers=True** keeps worker processes alive between epochs. Without it, workers are respawned at the start of each epoch, adding process start-up overhead that becomes measurable at large worker counts. It requires `num_workers > 0`.\n\n**drop_last=True** discards the final batch if it is smaller than `batch_size`. BatchNorm statistics computed from a batch of two or three samples are noisy, and dropping the remainder avoids this. You lose a few samples per epoch, but it is often worth it for stability.\n\n**batch_size** trades noise against throughput. Smaller batches produce noisier gradient estimates, which acts as implicit regularisation. Larger batches use more GPU memory but allow more parallelism. Tensor cores operate on fixed-size tiles, so batch sizes and layer dimensions that are multiples of 8 (for float16 and bfloat16) map onto them without padding; powers of two satisfy this automatically.\n\n**.to(device)** moves a tensor to the target device. For tensors it is not in-place: it returns a new tensor and leaves the original unchanged. For example, `X_batch.to('cuda')` returns a new tensor on the GPU; `X_batch` itself remains on the CPU.\n\nSetting seeds before constructing the model and loader makes runs repeatable on the same hardware and software. This is essential for reproducing experiments and for making model behaviour deterministic. 
Seeding matters in two main places: the data loader's shuffling and the initialisation of the model weights.\n\n```\nimport random\nimport numpy as np\nimport torch\n\ndef set_seed(seed: int = 42):\n    torch.manual_seed(seed)\n    torch.cuda.manual_seed_all(seed)\n    np.random.seed(seed)\n    random.seed(seed)\n    torch.backends.cudnn.deterministic = True\n    torch.backends.cudnn.benchmark     = False\n\nset_seed(42)\n```\n\n`torch.manual_seed` seeds PyTorch's default generator on the CPU and, in current versions, on all CUDA devices as well; `torch.cuda.manual_seed_all` makes the GPU seeding explicit and is harmless to keep. NumPy and Python's `random` are independent RNGs that PyTorch does not touch, which is why `set_seed` seeds them separately.\n\n`cudnn.deterministic = True` forces cuDNN to use deterministic convolution algorithms. Some cuDNN kernels are non-deterministic by default for throughput. The deterministic alternatives are slightly slower, but in practice this rarely matters during development.\n\n`cudnn.benchmark = False` should be paired with `deterministic = True`. When `benchmark = True`, cuDNN profiles several algorithms per input shape and picks the fastest, and that selection can itself vary between runs. Fixing it to `False` removes this source of run-to-run variation.\n\nWhen `num_workers > 0`, each DataLoader worker is a separate process with its own RNG state. PyTorch seeds each worker's generator from a base seed drawn in the main process (or from `generator`, if you pass one) plus the worker id. Passing a seeded `generator` makes those seeds, and the shuffle order, reproducible; a `worker_init_fn` can then derive the NumPy and `random` seeds from `torch.initial_seed()` so augmentation code that uses them is covered too. Use a named function rather than a lambda, because lambdas cannot be pickled when workers are started with the `spawn` method (the default on Windows and macOS):\n\n```\ndef seed_worker(worker_id):\n    # torch has already seeded this worker; reuse that seed for NumPy and random\n    worker_seed = torch.initial_seed() % 2**32\n    np.random.seed(worker_seed)\n    random.seed(worker_seed)\n\ng = torch.Generator()\ng.manual_seed(42)\n\nloader = DataLoader(\n    dataset,\n    batch_size=64,\n    shuffle=True,\n    num_workers=2,\n    generator=g,\n    worker_init_fn=seed_worker,\n)\n```\n\n`nn.Module` provides parameter tracking, device movement, train/eval mode switching, and serialisation. 
Each subclass needs an `__init__` (to register all the submodules) and a `forward` (to define the computation).\n\n``` python\nclass MLP(nn.Module):\n    def __init__(self, in_features, hidden, out_features):\n        super().__init__()\n        self.net = nn.Sequential(\n            nn.Linear(in_features, hidden),\n            nn.ReLU(),\n            nn.Linear(hidden, hidden),\n            nn.ReLU(),\n            nn.Linear(hidden, out_features),\n        )\n\n    def forward(self, x):\n        return self.net(x)\n\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\nmodel  = MLP(in_features=2, hidden=128, out_features=3).to(device)\n```\n\n`super().__init__()` is required: it initialises the module registry. Assigning `nn.Module` or `nn.Parameter` instances as attributes in `__init__` registers them automatically. Plain attributes, including ordinary tensors, are not registered, so they are excluded from parameter iteration, serialisation, and device movement.\n\n**.to(device)** moves all registered parameters and buffers to the target device, modifying the module in place. Call it before constructing the optimiser. For a device-only move, `nn.Module.to()` typically swaps each parameter's `.data` in place, so the optimiser's references remain correct. When a dtype conversion is combined (e.g. `.half().to(device)`), `nn.Module.to()` can allocate new `nn.Parameter` objects and replace them in the module's internal registry; an optimiser constructed before this retains references to the originals and applies updates to them instead of the converted parameters.\n\n**register_buffer** is for tensors that should follow the module (move with `.to(device)`, appear in `state_dict`) but are not trained parameters. 
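\n\nAs a minimal sketch (the `Standardise` module and its fixed `mean` are hypothetical, chosen only for illustration), a registered buffer appears in `state_dict()` but not in `parameters()`, so it is saved with the model while the optimiser never touches it:\n\n```python\nimport torch\nimport torch.nn as nn\n\n\nclass Standardise(nn.Module):\n    # hypothetical module: subtracts a fixed mean (buffer), then applies a learned scale (parameter)\n    def __init__(self, mean):\n        super().__init__()\n        self.register_buffer('mean', torch.as_tensor(mean, dtype=torch.float32))\n        self.scale = nn.Parameter(torch.ones(1))\n\n    def forward(self, x):\n        return (x - self.mean) * self.scale\n\n\nm = Standardise([0.5, -1.0])\nparam_names = [name for name, _ in m.named_parameters()]\nstate_keys = list(m.state_dict().keys())\n\nprint(param_names)           # ['scale']: the buffer is not a parameter\nprint(state_keys)            # ['scale', 'mean']: but it is saved with the model\nprint(m.mean.requires_grad)  # False\nprint(m(torch.tensor([[1.5, 0.0]])))\n```\n\nBecause `mean` is a buffer, `m.to(device)` moves it alongside `scale`, whereas a plain tensor attribute would stay behind on the original device.\n\n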
BatchNorm's `running_mean` and `running_var` are buffers, as is the causal attention mask in many transformer implementations.\n\n**model.parameters()** returns every registered `nn.Parameter`, including frozen ones with `requires_grad=False`. `model.named_parameters()` yields the same parameters as `(name, parameter)` pairs.\n\nCalling `model(x)` invokes `__call__`, not `forward` directly, a distinction new users often miss. The wrapper runs any forward pre-hooks, then `forward`, then any forward hooks, and sets up any registered backward hooks. The distinction matters when using hooks for instrumentation or gradient modification.\n\n**torch.compile(model)** (PyTorch 2.0+) traces the forward pass and, with the default Inductor backend, emits optimised Triton kernels on GPU. It fuses adjacent element-wise operations, reducing memory traffic. The first forward pass is slow because it triggers compilation, but subsequent ones are often 10–30% faster on GPU, sometimes more on inference-heavy workloads.\n\n**The PyTorch Parameter**\n\nThere are two things we care about in the `nn.Parameter` class: the values stored in `.data` and the `.grad` attribute. The backward pass accumulates gradients additively into `.grad`; the optimiser reads `.grad` to compute the update and writes the result back into the parameter's values.\n\nYou can exclude a parameter from gradient computation by setting `requires_grad=False`. This is useful for freezing layers, for inference-only models, and for non-trainable parameters.\n\nIn training mode, `Dropout` applies random masking and rescales, and `BatchNorm` computes statistics from the current batch rather than its stored running estimates.\n\n```\nmodel.train()                          # [1]\nfor X_batch, y_batch in loader:\n    ...\n```\n\n`model.train()` sets a flag on the module and all its submodules. Two common layers read it during their forward pass.\n\n**Dropout** samples a Bernoulli mask with probability `p` of zeroing each element, then scales surviving activations by $1/(1-p)$ to preserve their expected magnitude. In eval mode it acts as an identity function. 
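\n\nA quick check of the rescaling, as a sketch: with `p=0.5`, every surviving element of an all-ones input should come out as exactly 2.0, the mean should stay close to 1.0, and eval mode should return the input unchanged. The seed and tensor size are arbitrary choices.\n\n```python\nimport torch\nimport torch.nn as nn\n\ntorch.manual_seed(0)\np = 0.5\ndrop = nn.Dropout(p=p)\nx = torch.ones(10_000)\n\ndrop.train()\ny_train = drop(x)\nkept = y_train[y_train != 0]\nprint(kept.unique())           # survivors scaled by 1 / (1 - p) = 2.0\nprint(y_train.mean().item())   # close to 1.0: expected magnitude preserved\n\ndrop.eval()\ny_eval = drop(x)\nprint(torch.equal(y_eval, x))  # True: identity in eval mode\n```\n\n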
The reason to apply the scaling at train time (so-called inverted dropout) is that no rescaling is needed at inference.\n\n**BatchNorm** during training computes the mean and variance of the current batch, normalises against them, and updates its stored `running_mean` and `running_var` via an exponential moving average. In eval mode it uses those stored estimates instead of the batch statistics.\n\n`LayerNorm`, `GELU`, `ReLU`, and most other layers are unaffected by the training flag and behave identically in both modes. The flag usually only matters if your architecture contains `Dropout` or a `BatchNorm` variant.\n\nOmitting `model.train()` after validation does not raise an exception. The model trains with dropout disabled and BatchNorm using frozen statistics, both of which alter the effective learning dynamics.\n\nSwitching to eval mode saves little memory by itself; the larger reduction comes from `torch.no_grad()`, which disables graph construction entirely: no `grad_fn` is attached to output tensors and no intermediate activations are stored for backward. Used together during validation, they can roughly halve memory use compared to a training-mode forward pass.\n\nZeroing clears `.grad` before the next forward pass. Without this call, gradients from the previous batch add to the current one.\n\n```\nfor X_batch, y_batch in loader:\n    optimiser.zero_grad()              # [2]\n    ...\n```\n\n`optimiser.zero_grad()` resets every parameter's `.grad` attribute before the next forward pass. PyTorch accumulates gradients additively: each `.backward()` call adds to `.grad` rather than replacing it. 
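\n\nThe accumulation is easy to see on a single scalar parameter. This minimal sketch calls `backward()` twice without zeroing, then resets `.grad` to `None`, which is what `zero_grad()` does by default in PyTorch 2.x:\n\n```python\nimport torch\n\nw = torch.tensor(2.0, requires_grad=True)\n\n(3 * w).backward()              # d(3w)/dw = 3\nafter_first = w.grad.item()     # 3.0\n\n(5 * w).backward()              # d(5w)/dw = 5 is added, not substituted\nafter_second = w.grad.item()    # 8.0\n\nw.grad = None                   # what zero_grad(set_to_none=True) does\n(5 * w).backward()\nafter_reset = w.grad.item()     # 5.0\n\nprint(after_first, after_second, after_reset)\n```\n\n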
This behaviour enables gradient accumulation: summing gradients over multiple micro-batches before stepping is equivalent to training with a proportionally larger batch.\n\n```\n# gradient accumulation: effective batch size = batch_size × ACCUM_STEPS\nACCUM_STEPS = 4\noptimiser.zero_grad()\nfor i, (X_batch, y_batch) in enumerate(loader):\n    logits = model(X_batch)\n    loss   = criterion(logits, y_batch) / ACCUM_STEPS   # scale to match mean reduction\n    loss.backward()                                       # adds to .grad each time\n    if (i + 1) % ACCUM_STEPS == 0:\n        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n        optimiser.step()\n        optimiser.zero_grad()\n```\n\nAccumulating gradients across four small batches and stepping once is mathematically equivalent to training with a batch four times larger, assuming the loss is mean-reduced and the micro-batches are the same size (BatchNorm statistics, however, are still computed per micro-batch). When a full batch does not fit in GPU memory, accumulation is the standard approach.\n\n**zero_grad(set_to_none=True)** sets `.grad` to `None` rather than filling it with zeros. It is slightly faster (it skips the zero-fill pass) and uses less memory (the gradient tensors are freed). This is the default from PyTorch 2.0 onwards. Be careful in code that reads `.grad` directly, such as custom logging or gradient manipulation: after zeroing it is `None`, not a tensor of zeros.\n\nThe forward pass calls `forward()` and builds the autograd computation graph. Every operation on a tensor with `requires_grad=True` is recorded: what the inputs were, what the output was, and how to compute the local vector-Jacobian product for backward.\n\n```\nlogits = model(X_batch)                # [3]\n```\n\nEach node in the graph represents one operation and saves whatever its backward function will need, typically some of its input or output tensors. 
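\n\nYou can inspect the recorded graph directly. In this sketch (a single `nn.Linear` layer and a squared-output loss, both arbitrary), every tensor produced by a recorded operation carries a `grad_fn` node, while leaf tensors and anything computed under `torch.no_grad()` do not:\n\n```python\nimport torch\nimport torch.nn as nn\n\ntorch.manual_seed(0)\nlayer = nn.Linear(4, 2)\nx = torch.randn(3, 4)\n\nout = layer(x)\nloss = out.pow(2).mean()\n\n# each non-leaf result points at the backward node that produced it\nprint(loss.grad_fn)\n# next_functions links that node to the nodes for its inputs, back towards the leaves\nprint(loss.grad_fn.next_functions)\n\n# leaves (the input and the parameters) have no grad_fn\nprint(x.grad_fn, layer.weight.grad_fn)   # None None\n\nwith torch.no_grad():\n    out_no_graph = layer(x)\nprint(out_no_graph.grad_fn, out_no_graph.requires_grad)   # None False\n```\n\n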
The recorded graph lets PyTorch compute the gradient of a scalar loss with respect to every parameter automatically.\n\nCrucially, this is a dynamic graph: it is constructed on the fly during each forward pass, so it can change from one pass to the next. That makes it easy to use conditional branches, loops, and variable input sizes.\n\nThe graph is a directed acyclic graph (DAG) linking the loss back to the leaf parameters. Activations from every layer are kept alive in memory until the backward pass has used them. For a batch of size $B$ and a network with $L$ layers, activation memory scales as $O(B \\cdot L)$ in the naive case.\n\n**Gradient checkpointing** (`torch.utils.checkpoint.checkpoint`) reduces this. Instead of storing all intermediate activations, it keeps only the inputs to each checkpointed segment and re-runs that segment's forward pass during backward to recompute the rest on demand. Peak activation memory then depends on the number of segment boundaries and the size of the largest segment rather than on total depth, at the cost of roughly one extra forward pass through the checkpointed segments.\n\n``` python\nfrom torch.utils.checkpoint import checkpoint\n\n# a method on an nn.Module whose __init__ defines block1, block2 and head\ndef forward(self, x):\n    x = checkpoint(self.block1, x, use_reentrant=False)\n    x = checkpoint(self.block2, x, use_reentrant=False)\n    return self.head(x)\n```\n\nInside `torch.no_grad()`, no graph is constructed and no activations are stored, so memory use can drop by roughly half compared to a training-mode forward pass.\n\nComputing the loss reduces the logits to a scalar. For `CrossEntropyLoss`, this is log-softmax followed by negative log-likelihood, averaged over the batch. The result is a scalar tensor with a `grad_fn`, still connected to the graph.\n\n```\nloss = criterion(logits, y_batch)      # [4]\n```\n\nCross-entropy loss is the standard choice for multi-class classification. 
It measures the difference between the predicted probability distribution (from the model's logits) and the true distribution (the one-hot encoded labels).\n\n`nn.CrossEntropyLoss` is a composition of two operations: `LogSoftmax` followed by `NLLLoss`. Computing the log-softmax in one step is more numerically stable than computing the softmax probabilities and then taking their log.\n\nThe underlying computation uses the log-sum-exp trick to prevent overflow:\n\n$$\\log \\sum_{j} e^{z_j} = m + \\log \\sum_{j} e^{z_j - m}, \\qquad m = \\max_j z_j$$\n\nSubtracting $m$ before exponentiation keeps every exponent at or below zero, so nothing overflows. For logits $z$ and target class $y$, the loss for a single example is then:\n\n$$\\ell(z, y) = -z_y + \\log \\sum_{j} e^{z_j}$$\n\nand the batch loss is the mean over examples.\n\n**Common arguments:**\n\n`weight` accepts a 1D tensor of per-class weights, applied to each sample's loss contribution. Use this for class imbalance, i.e. to upweight rare classes.\n\n`label_smoothing=0.1` distributes a fraction of the probability mass uniformly across all classes rather than concentrating it on the target. With $K$ classes and smoothing $\\varepsilon$, the effective target distribution becomes $q_k = (1 - \\varepsilon)\\,\\mathbb{1}[k = y] + \\varepsilon / K$. It discourages overconfident predictions and is widely used, for example in image classification and machine translation.\n\n`ignore_index` (default `-100`) masks positions with that target value out of the loss. It is used in sequence modelling to exclude padding and masked tokens.\n\n`reduction='mean'` divides the summed loss by the number of non-ignored targets (or by their total weight when `weight` is set); `'sum'` does not divide. Switching between them changes the loss scale and therefore the effective learning rate.\n\nLogging the loss with `loss.item()` extracts a Python float, detached from the graph. Logging the tensor itself keeps its graph alive for as long as the reference lives; if those tensors are accumulated across steps, for example summed into a running total, memory grows every step.\n\n`loss.backward()` computes gradients via the chain rule and accumulates them into `.grad` on each parameter. It does not modify the weights.\n\n```\nloss.backward()                        # [5]\n```\n\n`backward()` implements reverse-mode automatic differentiation (backpropagation). 
For a scalar output, a single backward pass computes gradients with respect to all parameters. Forward-mode differentiation would need a separate pass per parameter. At the parameter counts typical of modern networks, reverse mode is the only practical option.\n\nThe backward pass walks the computation graph from the loss in reverse topological order. At each node, it applies that operation's backward function (its vector-Jacobian product) and propagates the result to the node's inputs. At leaf parameters, the result accumulates into `.grad`.\n\nAs mentioned above, `.grad` accumulates additively. If `.backward()` is called without zeroing gradients first, the new gradients add to whatever was already in `.grad`. Gradient accumulation relies on this behaviour, and `zero_grad()` resets it.\n\n**retain_graph=True** prevents the graph from being freed after backward. Normally the graph is freed immediately because it is assumed to be used once. You need `retain_graph=True` when calling backward multiple times through the same graph (for example, in a GAN where you backpropagate through shared components for both the discriminator and generator losses).\n\n**create_graph=True** allows differentiation through the backward pass itself. The backward computation becomes differentiable, enabling higher-order derivatives: Hessian-vector products, MAML-style meta-gradients, or second-order optimisers.\n\nGradient clipping rescales all gradients proportionally when their global norm exceeds `max_norm`. It is called after `backward()` and before `step()`.\n\n```\ntorch.nn.utils.clip_grad_norm_(        # [6]\n    model.parameters(), max_norm=1.0\n)\n```\n\n`clip_grad_norm_` computes the global norm across all parameters, treating their gradients as one long vector:\n\n$$\\|g\\| = \\sqrt{\\sum_{p} \\|g_p\\|_2^2}$$\n\nwhere $g_p$ is the gradient of parameter $p$. If $\\|g\\| > \\text{max\\_norm}$, every gradient tensor is multiplied by $\\text{max\\_norm} / (\\|g\\| + 10^{-6})$. Relative direction is preserved; only the magnitude is bounded.\n\nGradient spikes are common in transformer training. 
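\n\nAs a small numerical check of the rule above (the two parameters and their gradients are made up for the example), gradients of $[3, 0]$ and $[4]$ have a global norm of 5, so clipping to `max_norm=1.0` should scale both by about $1/5$:\n\n```python\nimport torch\n\na = torch.nn.Parameter(torch.zeros(2))\nb = torch.nn.Parameter(torch.zeros(1))\na.grad = torch.tensor([3.0, 0.0])\nb.grad = torch.tensor([4.0])\n\n# global norm = sqrt(3^2 + 0^2 + 4^2) = 5, which exceeds max_norm, so both tensors are rescaled\ntotal_norm = torch.nn.utils.clip_grad_norm_([a, b], max_norm=1.0)\n\nprint(total_norm)        # the norm measured before clipping\nprint(a.grad, b.grad)    # roughly [0.6, 0.0] and [0.8]: same direction, unit norm\n```\n\nThe return value is the pre-clipping norm, which is worth logging: a sudden jump in it is often the first visible sign of a gradient spike.\n\n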
In transformers, one possible source is attention: the softmax can saturate into near-one-hot distributions, and the resulting instability can propagate large gradient norms to earlier parameters via the chain rule.\n\nGlobal norm clipping is standard practice: `max_norm=1.0` is used in the GPT-3 paper, much subsequent language model work, and many vision transformer papers. It is less commonly necessary for small MLPs on well-conditioned data.\n\n`clip_grad_value_` clips each gradient component to $[-c, c]$, where $c$ is `clip_value`, rather than rescaling by the global norm. It does not preserve gradient direction and is less commonly used.\n\nPlacing clipping before `backward()` has no effect (`.grad` is empty). Placing it after `step()` clips gradients that have already been applied to the weights.\n\n`optimiser.step()` reads `.grad` on every parameter and applies the update rule, updating any optimiser state such as moment estimates. Weights change here and only here.\n\n```\noptimiser.step()                       # [7]\n```\n\nFor plain SGD with momentum $\\mu$, learning rate $\\eta$ and gradient $g_t$, PyTorch keeps a momentum buffer $b_t$:\n\n$$b_t = \\mu\\, b_{t-1} + g_t, \\qquad \\theta_t = \\theta_{t-1} - \\eta\\, b_t$$\n\nAdam maintains per-parameter estimates of the first moment (the running mean of the gradient) and the second moment (the running mean of the squared gradient, an uncentred variance):\n\n$$m_t = \\beta_1 m_{t-1} + (1 - \\beta_1)\\, g_t, \\qquad v_t = \\beta_2 v_{t-1} + (1 - \\beta_2)\\, g_t^2$$\n\nBoth estimates are biased toward zero early in training (they start at zero, so early values of $m_t$ and $v_t$ underestimate the true moments). Bias correction accounts for this:\n\n$$\\hat{m}_t = \\frac{m_t}{1 - \\beta_1^t}, \\qquad \\hat{v}_t = \\frac{v_t}{1 - \\beta_2^t}$$\n\nThe weight update is then:\n\n$$\\theta_t = \\theta_{t-1} - \\eta\\, \\frac{\\hat{m}_t}{\\sqrt{\\hat{v}_t} + \\epsilon}$$\n\nDefault hyperparameters: $\\beta_1 = 0.9$, $\\beta_2 = 0.999$, $\\epsilon = 10^{-8}$. The per-parameter adaptive learning rate makes Adam robust to varying gradient scales across parameters, which matters in deep networks where early and late layers often have gradients of very different magnitudes.\n\n**AdamW** decouples weight decay from the gradient update. Standard Adam with L2 regularisation adds a $\\lambda \\theta_{t-1}$ term to the gradient before the update, which then gets rescaled by the adaptive step size. AdamW instead applies weight decay directly to the weights, $\\theta_t \\leftarrow \\theta_{t-1} - \\eta \\lambda \\theta_{t-1}$, separately from the adaptive update. 
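\n\nA one-step sketch makes the difference concrete. The parameter value, gradient, learning rate and decay below are arbitrary. On the first step Adam divides the (decay-augmented) gradient by its own magnitude, so the update is almost exactly `lr` whether or not the decay term is folded in, while `AdamW` additionally subtracts `lr * weight_decay * p`:\n\n```python\nimport torch\n\n\ndef one_step(opt_cls):\n    # a single scalar parameter with a fixed gradient, so the two update rules can be compared directly\n    p = torch.nn.Parameter(torch.tensor([1.0]))\n    opt = opt_cls([p], lr=0.1, weight_decay=0.1)\n    p.grad = torch.tensor([0.5])\n    opt.step()\n    return p.item()\n\n\nadam_l2 = one_step(torch.optim.Adam)   # decay added to the gradient, then normalised away\nadamw = one_step(torch.optim.AdamW)    # decay applied directly to the weight\nprint(adam_l2, adamw)                  # roughly 0.90 and 0.89\n```\n\n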
The two are not equivalent; AdamW gives more predictable effective regularisation and is now standard for transformer training.\n\n**fused=True** (primarily for CUDA) fuses the parameter update across many parameters into a small number of kernels, avoiding the Python loop and most of the kernel-launch overhead. It can run roughly 30–50% faster than the default at large parameter counts.\n\n**foreach=True** uses the batched `torch._foreach_*` operations, which apply each step of the update to a whole list of tensors at once instead of looping over parameters one by one. Its speed sits between the plain for-loop implementation and `fused`; on CUDA, PyTorch already picks it by default when it can, and it is also available on CPU.\n\nThe optimiser stores moment state for every parameter. For Adam on a 7B-parameter model, that is 7B float32 values for $m$ and another 7B for $v$: $2 \\times 7 \\times 10^9 \\times 4$ bytes, roughly 56 GB of optimiser state at full precision. At that scale, training typically uses 8-bit optimisers (bitsandbytes) or shards state across devices (ZeRO).\n\nFor an epoch-based schedule, the scheduler is stepped once per epoch, after the last `optimiser.step()`, outside the batch loop.\n\n```\n    scheduler.step()                   # [8]\n```\n\nThe scheduler modifies the `lr` field of each parameter group in the optimiser. The most common mistake is calling an epoch-based scheduler inside the batch loop:\n\n```\n# wrong: lr decays len(loader) times per epoch instead of once\nfor X_batch, y_batch in loader:\n    optimiser.step()\n    scheduler.step()\n\n# correct\nfor X_batch, y_batch in loader:\n    optimiser.step()\nscheduler.step()\n```\n\n**Cosine annealing** decays the learning rate from `eta_max` (the optimiser's initial `lr`) to `eta_min` over `T_max` epochs:\n\n$$\\eta_t = \\eta_{\\min} + \\tfrac{1}{2}\\left(\\eta_{\\max} - \\eta_{\\min}\\right)\\left(1 + \\cos\\left(\\frac{\\pi\\, T_{cur}}{T_{\\max}}\\right)\\right)$$\n\nwhere $T_{cur}$ is the number of epochs since the start of the schedule.\n\nIn modern large-model training, the standard schedule combines a linear warmup with cosine decay. A warmup phase (typically 1–5% of total steps) limits the update magnitude while parameters are far from their trained values. 
Cosine decay reduces the rate for the remaining steps.\n\n```\n# linear warmup + cosine decay\n# (using HuggingFace's implementation as reference)\nfrom transformers import get_cosine_schedule_with_warmup\n\nscheduler = get_cosine_schedule_with_warmup(\n    optimiser,\n    num_warmup_steps=100,\n    num_training_steps=10_000,\n)\n\n# with this scheduler, call step() once per optimiser step, not once per epoch\nfor X_batch, y_batch in loader:\n    ...\n    optimiser.step()\n    scheduler.step()\n```\n\nMost schedulers' `step()` takes no arguments. `ReduceLROnPlateau` is the exception: it takes a metric value and reduces the learning rate only when that metric has stopped improving for `patience` epochs, so it must be called with the validation loss: `scheduler.step(val_loss)`.\n\n`model.eval()` changes layer behaviour; `torch.no_grad()` stops graph construction. They are independent operations, and you need both for validation.\n\n```\nmodel.eval()                           # [9]\nwith torch.no_grad():                  # [10]\n    val_logits = model(X_val)\n    val_loss   = criterion(val_logits, y_val)\n```\n\n`model.eval()` sets `self.training = False` on every module. **BatchNorm** switches from computing batch statistics to using its stored `running_mean` and `running_var`. **Dropout** switches from sampling Bernoulli masks to the identity function. Most other standard layers are unaffected.\n\n`torch.no_grad()` disables the construction of the autograd graph entirely. No `grad_fn` is attached to tensors produced inside the context, and no intermediate activations are saved. This can roughly halve memory use compared to a training-mode forward pass and makes computation slightly faster (less bookkeeping per operation).\n\n**torch.inference_mode()** is a stricter form of `no_grad()`. 
Tensors created inside the context have `is_inference() == True`, which prevents them from participating in autograd even if they escape the context manager. Because it skips some additional autograd bookkeeping (such as version counting), it can be somewhat faster than `no_grad()`.\n\nTrack training and validation metrics using `loss.item()`, not `loss`. Calling `.item()` extracts a Python float that is detached from the graph; holding a reference to the tensor keeps its full backward graph alive until that reference is dropped.\n\n```\nNUM_EPOCHS = 100\n\nfor epoch in range(NUM_EPOCHS):\n    model.train()\n    running_loss = 0.0\n    for X_batch, y_batch in loader:\n        optimiser.zero_grad()\n        loss = criterion(model(X_batch), y_batch)\n        loss.backward()\n        optimiser.step()\n        running_loss += loss.item()   # a float, so no graph is kept alive\n    train_loss = running_loss / len(loader)\n\n    model.eval()\n    with torch.no_grad():\n        val_logits = model(X_val)\n        val_loss   = criterion(val_logits, y_val).item()\n        val_acc    = (val_logits.argmax(1) == y_val).float().mean().item()\n\n    print(f'epoch {epoch:3d}  train {train_loss:.4f}  val {val_loss:.4f}  acc {val_acc:.3f}')\n```\n\n`torch.save` writes a checkpoint file; `torch.load` reads it back. 
Checkpoints allow a training run to survive crashes and preemption.\n\n```\n# save\ntorch.save({\n    'epoch':                epoch,\n    'model_state_dict':     model.state_dict(),\n    'optimiser_state_dict': optimiser.state_dict(),\n    'scheduler_state_dict': scheduler.state_dict(),\n    'val_loss':             val_loss,\n}, 'checkpoint.pt')\n\n# resume\ncheckpoint = torch.load('checkpoint.pt', map_location=device)\nmodel.load_state_dict(checkpoint['model_state_dict'])\noptimiser.load_state_dict(checkpoint['optimiser_state_dict'])\nscheduler.load_state_dict(checkpoint['scheduler_state_dict'])\nstart_epoch = checkpoint['epoch'] + 1\n```\n\n`state_dict()` returns an ordered dictionary of parameter and buffer tensors. It does not include the class definition, so you must first construct the model from its class and then call `load_state_dict` on that instance.\n\nThe optimiser state must be saved alongside the model to resume training correctly. For Adam, the state includes the per-parameter moment estimates $m$ and $v$, which take several hundred steps to warm up from zero. Resuming without them is equivalent to restarting the optimiser cold, which typically produces a loss spike at the resume point.\n\n`map_location=device` handles the common case where the checkpoint was saved on a different device than the one loading it (another GPU, or a GPU when you are loading on CPU).\n\nA common pattern saves only when validation loss improves, so the saved weights correspond to the best-generalising epoch rather than the final one:\n\n```\nbest_val_loss = float('inf')\n\nfor epoch in range(NUM_EPOCHS):\n    # ... 
training loop ...\n\n    model.eval()\n    with torch.no_grad():\n        val_loss = criterion(model(X_val), y_val).item()\n\n    if val_loss < best_val_loss:\n        best_val_loss = val_loss\n        torch.save(model.state_dict(), 'best_model.pt')\n```\n\nTo restore the best model after training: `model.load_state_dict(torch.load('best_model.pt', map_location=device))`.\n\nWhile we're working through the nitty-gritty of the training loop, it's worth spending a few minutes on some basic GPU optimisation techniques. None of these count as \"melting the hardware\", but almost all of them are free speed.\n\nPut the model and each batch on the *same* device. PyTorch raises an error if they differ, and keeping both on the GPU avoids repeated host-to-device transfers.\n\n```\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n\nmodel = MLP(in_features=2, hidden=128, out_features=3).to(device)\n# construct optimiser AFTER moving model: it captures parameter references\noptimiser = torch.optim.Adam(model.parameters(), lr=1e-3)\n\n# per-batch: move data to device\nfor X_batch, y_batch in loader:\n    X_batch = X_batch.to(device, non_blocking=True)\n    y_batch = y_batch.to(device, non_blocking=True)\n```\n\n`non_blocking=True` initiates the host-to-device transfer asynchronously and returns immediately, so the GPU can keep working on earlier queued operations while the new batch is still transferring. The overlap only helps when `pin_memory=True` is set in the DataLoader; copies from unpinned memory cannot be fully asynchronous.\n\nModern GPUs have dedicated tensor cores that can execute float16 (and bfloat16) matrix multiplications 4–8× faster than float32 on the same hardware. Mixed precision training keeps the weights in float32 but runs eligible operations in the forward and backward passes in float16.\n\nThis makes the computation more efficient, but float16 has a much smaller dynamic range than float32, so small gradient values underflow to zero, producing incorrect updates. 
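\n\nThe underflow is easy to reproduce on a single value. In this sketch the gradient value 1e-8 is arbitrary, and the scale of $2^{16}$ mirrors `GradScaler`'s default starting scale; scaling before the cast keeps the value representable, and dividing in float32 afterwards recovers it:\n\n```python\nimport torch\n\ntiny_grad = torch.tensor(1e-8)                 # a small but legitimate float32 gradient value\nprint(tiny_grad.half())                        # underflows to zero in float16\n\nscale = 2.0 ** 16                              # GradScaler's default initial scale\nscaled = (tiny_grad * scale).half()            # about 6.6e-4, comfortably inside float16's range\nrecovered = scaled.float() / scale             # unscale in float32 before the optimiser step\nprint(recovered.item())                        # close to 1e-8 again\n\n# bfloat16 keeps float32's exponent range, so its smallest normal value is far smaller\nprint(torch.finfo(torch.float16).tiny, torch.finfo(torch.bfloat16).tiny)\n```\n\n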
The compromise has two parts. First, keep a float32 master copy of the weights, so the many small updates applied over training are not lost to float16 rounding. Second, scale the loss so that gradients stay representable in float16 during the backward pass.\n\n``` python\nscaler = torch.amp.GradScaler('cuda')\n\nfor X_batch, y_batch in loader:\n    optimiser.zero_grad()\n\n    with torch.amp.autocast('cuda', dtype=torch.float16):\n        logits = model(X_batch)\n        loss   = criterion(logits, y_batch)\n\n    scaler.scale(loss).backward()           # backward on the scaled loss\n    scaler.unscale_(optimiser)              # unscale before clipping\n    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n    scaler.step(optimiser)                  # step only if no inf/nan\n    scaler.update()                         # adjust scale factor\n```\n\n`torch.amp.autocast` and `torch.amp.GradScaler` are the current, device-generic API. The older `from torch.cuda.amp import autocast, GradScaler` still works but is deprecated.\n\n**Why GradScaler:** to keep small gradients from underflowing in float16, GradScaler multiplies the loss by a large scale factor before backward (65,536, i.e. 2^16, by default). This scales every gradient by the same factor and keeps their magnitudes within the float16 range. Before the optimiser step, gradients are divided back to their true scale. If an overflow is detected (inf or nan in any gradient), the step is skipped and the scale factor is reduced; after a long run of steps without overflow, it is increased again.\n\n**bfloat16** (`dtype=torch.bfloat16`) has the same exponent range as float32 but fewer mantissa bits, so gradients do not underflow and GradScaler is not needed. It is the preferred dtype on recent hardware (A100, H100, TPUs).\n\n``` python\nwith torch.amp.autocast('cuda', dtype=torch.bfloat16):\n    logits = model(X_batch)\n    loss   = criterion(logits, y_batch)\n\nloss.backward()                # no scaler needed\noptimiser.step()\n```\n\nFor GPU training, the bottleneck is often not the GPU but the data pipeline feeding it. 
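The usual remedy is to prepare upcoming batches in the background while the current one is being processed. Below is a toy, torch-free sketch of that idea built on a thread and a bounded queue. `load_batch`, `prefetching_loader` and the sleep durations are hypothetical stand-ins, and PyTorch's DataLoader uses worker processes rather than a thread.

```python
import queue
import threading
import time
from collections.abc import Iterator


def load_batch(i: int) -> list[int]:
    # Hypothetical stand-in for reading and collating one batch.
    time.sleep(0.01)
    return [i] * 4


def prefetching_loader(num_batches: int, prefetch: int = 2) -> Iterator[list[int]]:
    # Bounded queue: the background thread can hold at most `prefetch`
    # finished batches ahead of the consumer before put() blocks.
    q: queue.Queue = queue.Queue(maxsize=prefetch)
    done = object()  # sentinel marking the end of the data

    def worker() -> None:
        for i in range(num_batches):
            q.put(load_batch(i))
        q.put(done)

    threading.Thread(target=worker, daemon=True).start()
    while (batch := q.get()) is not done:
        yield batch


seen = []
for batch in prefetching_loader(num_batches=5):
    time.sleep(0.01)  # stand-in for the training step; the next batch loads meanwhile
    seen.append(batch[0])

print(seen)  # [0, 1, 2, 3, 4]
```

The queue's `maxsize` plays roughly the role of `prefetch_factor`. It bounds how far ahead loading can run, which keeps memory use fixed while loading overlaps the training step.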
The DataLoader with `num_workers=4, pin_memory=True, persistent_workers=True, prefetch_factor=2` usually keeps the GPU fed without stalling. `prefetch_factor=2` (the default when `num_workers > 0`) means each worker loads up to two batches in advance, so up to `2 * num_workers` batches are ready ahead of consumption.\n\n`torch.backends.cudnn.benchmark = True` makes cuDNN benchmark several convolution algorithms the first time it sees a given input shape and cache the fastest one; subsequent batches with the same shape reuse that choice. Do not use this if input shapes vary between batches: the benchmarking cost is paid again for each new shape. It only affects cuDNN convolutions, so a pure MLP like the one in this essay does not benefit.\n\n``` python\nmodel = torch.compile(model)\n```\n\nAvailable in PyTorch 2.0 and above, `torch.compile` captures the forward pass and, with the default Inductor backend, emits optimised Triton kernels on GPU. The main benefit is operation fusion: adjacent element-wise operations (activation + dropout + residual add, for example) are merged into a single kernel, reducing memory reads and writes. Typical speedups are 10–30% on GPU training, higher for inference workloads and architectures with many small ops.\n\n`torch.compile(model, mode='max-autotune')` runs additional profiling to choose the best kernel for each operation. Compilation takes longer, but the resulting model is usually faster, and the cost amortises over long training runs.\n\nThe first forward pass triggers compilation and is slow. 
`torch._dynamo.reset()` clears the compilation cache if you need to recompile (e.g., after changing the model structure).\n\n``` python\ndevice = torch.device('cuda')\nmodel  = torch.compile(MLP(...).to(device))\n# criterion, optimiser (built from model.parameters()), scheduler and\n# loader (with pin_memory=True) are constructed as before, after this line\n\ntorch.backends.cudnn.benchmark = True\n\nfor epoch in range(NUM_EPOCHS):\n    model.train()\n    for X_batch, y_batch in loader:\n        X_batch = X_batch.to(device, non_blocking=True)\n        y_batch = y_batch.to(device, non_blocking=True)\n\n        optimiser.zero_grad(set_to_none=True)\n\n        with torch.amp.autocast('cuda', dtype=torch.bfloat16):\n            logits = model(X_batch)\n            loss   = criterion(logits, y_batch)\n\n        loss.backward()                                        # no scaler: bf16 doesn't underflow\n        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n        optimiser.step()\n\n    scheduler.step()\n\n    model.eval()\n    with torch.inference_mode():\n        val_logits = model(X_val.to(device))\n        val_loss   = criterion(val_logits, y_val.to(device))\n```\n\nThe same loop again, with each step annotated. The notes follow the order of the `# [n]` markers in the code below them.\n\n- `X_train`: shape `(N, features)`. For the spiral dataset, N=2000, features=2.\n- `y_train`: shape `(N,)`. Not one-hot: `CrossEntropyLoss` expects class indices directly.\n- `TensorDataset` pairs inputs and labels by index. `__getitem__` returns `(X[i], y[i])`.\n- `DataLoader` handles batching, reshuffling every epoch, and optional parallel prefetching via `num_workers`.\n- `MLP` is an `nn.Module` subclass. Register submodules in `__init__`; define the computation in `forward`.\n- `.to(device)` moves all parameters and buffers. Do this before constructing the optimiser.\n- `CrossEntropyLoss` = LogSoftmax + NLLLoss. Pass raw logits, not softmax outputs.\n- `model.parameters()` supplies the tensors to optimise.\n- `scheduler.step()` adjusts `lr` in the optimiser's param groups.\n- `model.train()`: dropout masks active, batchnorm uses batch statistics.\n- `zero_grad()`: clears `.grad`. PyTorch accumulates gradients by default; without this, gradients from successive batches add up.\n- The loss tensor's `grad_fn` is the entry point for backward.\n- `backward()`: reverse-mode autodiff. Populates `.grad` on every leaf parameter. Weights are unchanged.\n- `optimiser.step()`: reads `.grad`, applies the Adam update and updates the moment estimates. Weights change.\n- `scheduler.step()`: adjusts lr. Once per epoch, after `optimiser.step()`, outside the batch loop.\n- `model.eval()`: dropout becomes a pass-through, batchnorm uses running statistics.\n- `torch.no_grad()` (or `inference_mode()`): no graph construction. Faster, lower memory.\n\n``` python\nimport torch\nimport torch.nn as nn\nfrom torch.utils.data import DataLoader, TensorDataset\n\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n\n# ── data ──────────────────────────────────────────────────────────────────────\n# make_labels (the spiral labelling function) and the held-out X_val, y_val are not shown\nX_train = torch.randn(2000, 2)                                     # [1]\ny_train = make_labels(X_train)                                     # [2]\n\ndataset = TensorDataset(X_train, y_train)                          # [3]\nloader  = DataLoader(dataset, batch_size=64, shuffle=True,         # [4]\n                     num_workers=2, pin_memory=True)\n\n# ── model ──────────────────────────────────────────────────────────────────────\nclass MLP(nn.Module):                                              # [5]\n    def __init__(self, in_features, hidden, out_features):\n        super().__init__()\n        self.net = nn.Sequential(\n            nn.Linear(in_features, hidden), nn.ReLU(),\n            nn.Linear(hidden, hidden),      nn.ReLU(),\n            nn.Linear(hidden, out_features),\n        )\n    def forward(self, x):\n        return self.net(x)\n\nmodel = MLP(in_features=2, hidden=128, out_features=3).to(device)  # [6]\n\n# ── loss, optimiser, scheduler ────────────────────────────────────────────────\ncriterion = nn.CrossEntropyLoss()                                  # [7]\noptimiser = torch.optim.Adam(model.parameters(), lr=1e-3)          # [8]\nscheduler = torch.optim.lr_scheduler.CosineAnnealingLR(            # [9]\n    optimiser, T_max=100\n)\n\n# ── training loop ─────────────────────────────────────────────────────────────\nfor epoch in range(100):\n\n    model.train()                                                   # [10]\n\n    for X_batch, y_batch in loader:\n        X_batch, y_batch = X_batch.to(device), y_batch.to(device)\n\n        optimiser.zero_grad()                                       # [11]\n\n        logits = model(X_batch)                                     # [12]\n\n        loss = criterion(logits, y_batch)                           # [13]\n\n        loss.backward()                                             # [14]\n\n        torch.nn.utils.clip_grad_norm_(                             # [15]\n            model.parameters(), max_norm=1.0\n        )\n\n        optimiser.step()                                            # [16]\n\n    scheduler.step()                                                # [17]\n\n    # ── validation ────────────────────────────────────────────────────────────\n    model.eval()                                                    # [18]\n    with torch.no_grad():                                           # [19]\n        val_logits = model(X_val.to(device))\n        val_loss   = criterion(val_logits, y_val.to(device))\n        val_acc    = (val_logits.argmax(1) == y_val.to(device)).float().mean()\n```\n\n", "url": "https://wpnews.pro/news/the-annotated-pytorch-training-loop", "canonical_source": "https://idlemachines.co.uk/essays/pytorch-training-loop", "published_at": "2026-06-22 23:44:59+00:00", "updated_at": "2026-06-25 19:42:41.258202+00:00", "lang": "en", "topics": ["machine-learning", "developer-tools"], "entities": ["PyTorch", "Adam", "CosineAnnealingLR", "DataLoader", "TensorDataset", "CrossEntropyLoss", "MLP"], "alternates": {"html": "https://wpnews.pro/news/the-annotated-pytorch-training-loop", "markdown": "https://wpnews.pro/news/the-annotated-pytorch-training-loop.md", 
"text": "https://wpnews.pro/news/the-annotated-pytorch-training-loop.txt", "jsonld": "https://wpnews.pro/news/the-annotated-pytorch-training-loop.jsonld"}}