# The annotated PyTorch training loop

> Source: <https://idlemachines.co.uk/essays/pytorch-training-loop>
> Published: 2026-06-22 23:44:59+00:00

What each line does and what breaks if you move it.

Building a PyTorch training loop is fairly straightforward, but getting everything in the right place and in the right order can feel surprisingly fragile. There are loads of moving parts and after the most basic errors are fixed, most of the other mistakes can be pretty hard to spot. Training runs will fail to converge, produce incorrect results, or consume excessive memory if lines are misplaced.

The sections below will go through each operation in sequence, explaining exactly how to write each section, and all the common mistakes to watch out for.
Distributed training, FSDP, and multi-GPU setups are out of scope here, but we'll come back to that in a future essay.
*(The animation above was produced by running the loop on synthetic data and capturing the decision boundary at each epoch.)*

Let's look, first of all, at the complete training loop. You don't need to understand or memorise it yet, just get a feel for the structure.

``` python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# --- data ---
dataset = TensorDataset(X_train, y_train)
loader  = DataLoader(dataset, batch_size=64, shuffle=True)

# --- model, loss, optimiser ---
model     = MLP(in_features=2, hidden=128, out_features=3)
criterion = nn.CrossEntropyLoss()
optimiser = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimiser, T_max=100)

# --- loop ---
for epoch in range(100):
    model.train()
    for X_batch, y_batch in loader:
        optimiser.zero_grad()
        logits = model(X_batch)
        loss   = criterion(logits, y_batch)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimiser.step()
    scheduler.step()

    model.eval()
    with torch.no_grad():
        val_logits = model(X_val)
        val_loss   = criterion(val_logits, y_val)
```

Now let's go through each line and understand what it does, and how not to break it. We'll start with the most common failures: ways to break the training loop by getting the placement a little bit wrong.

The reason to memorise these is that none of them will raise an exception. Over time you'll get a sense for what kind of errors to look for in your training runs, but for the first few times this crib sheet will help you out.

| Line | Wrong position | What breaks |
|---|---|---|
| `model.to(device)` | After `optimiser = ...` | When a dtype conversion is combined (e.g. `model.half()`), `nn.Module.to()` can allocate new `nn.Parameter` objects; the optimiser then holds references to the discarded originals and applies updates to them instead. |
| `optimiser.zero_grad()` | After `loss.backward()` | Gradients from multiple batches accumulate. Update uses their sum, not the current batch alone. |
| `clip_grad_norm_()` | Before `loss.backward()` | `.grad` is empty. The call is a no-op. |
| `clip_grad_norm_()` | After `optimiser.step()` | Clips gradients already applied. No effect. |
| `scheduler.step()` | Inside batch loop | LR decays `len(loader)` times per epoch instead of once. |
| Omit `model.train()` after `model.eval()` | n/a | Dropout disabled, BatchNorm frozen. The model trains in eval mode without error. |
| Omit `torch.no_grad()` during validation | n/a | Autograd graph builds on every validation batch. Memory grows until OOM. |
| Log `loss` instead of `loss.item()` | n/a | Pins the computation graph in memory for the duration of the logging call. |

Now let's go through each of these in detail.

There are two parts to the data pipeline in PyTorch: the `Dataset` and the `DataLoader`. The `Dataset` is just a Python object that implements `__len__` (how many elements are in the dataset) and `__getitem__` (which, unsurprisingly, gets an item). It can be a simple wrapper around tensors, or it can load data from disk on demand.

The `DataLoader` wraps a dataset and produces batches.

Each pass through the full dataset is one epoch. With `shuffle=True`, examples are presented in a different order each epoch.

```
dataset = TensorDataset(X_train, y_train)
loader  = DataLoader(
    dataset,
    batch_size=64,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
    persistent_workers=True,
)
```

`TensorDataset` pairs input and label tensors by index. Indexing with `dataset[i]` returns `(X[i], y[i])`. `DataLoader` calls `__getitem__` repeatedly, collates the results into batches, and optionally hands work to background worker processes.
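As a rough illustration of the `__len__` / `__getitem__` contract, here is a minimal custom dataset. `SquaresDataset` is a made-up example, not something from PyTorch or from the loop above; it assumes only that each item is returned as an `(input, label)` pair of tensors.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class SquaresDataset(Dataset):
    """Hypothetical dataset: item i is the pair (i, i**2) as float tensors."""

    def __init__(self, n: int):
        self.n = n

    def __len__(self) -> int:
        # How many examples the DataLoader can index.
        return self.n

    def __getitem__(self, idx: int):
        # Build one example on demand; a real dataset might read from disk here.
        x = torch.tensor([float(idx)])
        y = torch.tensor([float(idx) ** 2])
        return x, y


squares = SquaresDataset(10)
squares_loader = DataLoader(squares, batch_size=4, shuffle=False)

# Ten items in batches of four: two full batches and one remainder of two.
batch_shapes = [tuple(xb.shape) for xb, _ in squares_loader]
```

With `drop_last=True` the final, smaller batch would be skipped.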

**`num_workers`** spawns separate processes that prefetch batches in parallel with GPU compute. The main process blocks on `next()` only if a batch is not yet ready. Zero workers means the main process does all loading, which often bottlenecks GPU utilisation on data-heavy tasks. Two to four workers is practical, but the right number depends on CPU count and I/O speed.

**`pin_memory=True`** allocates batch tensors in pinned (page-locked) host memory. The GPU DMA engine can transfer directly from pinned memory without first staging the data through an intermediate buffer, reducing host-to-device transfer time. It only helps when you're transferring to CUDA.

**`persistent_workers=True`** keeps worker processes alive between epochs. Without it, workers are respawned at the start of each epoch, adding startup overhead that becomes measurable at large worker counts.

**`drop_last=True`** discards the final batch if it is smaller than `batch_size`. BatchNorm statistics computed from a batch of two or three samples are noisy, so dropping the remainder avoids this. The cost is a few discarded samples per epoch, which is often worth it for stability.

Smaller batches produce noisier gradient estimates, which acts as implicit regularisation. Larger batches use more GPU memory but allow more parallelism. Tensor cores work on fixed tile sizes (typically 16×16 or 8×16 depending on dtype), so setting batch sizes and layer dimensions to multiples of 8 or 16 is a cheap efficiency win.

**`.to(device)`** moves a tensor to the target device. For tensors, it is not in-place: it returns a new tensor and leaves the original unchanged. For example, `X_batch.to('cuda')` returns a new tensor on GPU; `X_batch` itself remains on CPU.

Setting seeds before constructing the model and loader gives the same results on every run. This is essential for reproducing experiments. The two main places randomness enters are the data loader's shuffling and the initialisation of the model weights.

```
import random

import numpy as np
import torch

def set_seed(seed: int = 42):
    torch.manual_seed(seed)                      # CPU generator
    torch.cuda.manual_seed_all(seed)             # every GPU
    np.random.seed(seed)                         # NumPy's global RNG
    random.seed(seed)                            # Python's built-in RNG
    torch.backends.cudnn.deterministic = True    # deterministic cuDNN kernels
    torch.backends.cudnn.benchmark     = False   # no per-run algorithm search

set_seed(42)
```

`torch.manual_seed` seeds the CPU generator. `torch.cuda.manual_seed_all` seeds every GPU. NumPy and Python's `random` are independent RNGs that PyTorch does not touch, so they need to be seeded separately.

`cudnn.deterministic = True` forces cuDNN to use deterministic convolution algorithms. Some cuDNN kernels are non-deterministic by default for throughput. The deterministic alternatives are slightly slower, but in practice this rarely matters during development.

`cudnn.benchmark = False` must be paired with `deterministic = True`. When `benchmark = True`, cuDNN profiles several algorithms per input shape and picks the fastest, a process that itself varies between runs. Fixing it to `False` makes sure you always get the same results.
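A quick way to see what seeding buys you is to build the same layer twice. This is a small CPU-only sketch; `init_weights` is a helper invented for the illustration.

```python
import torch
import torch.nn as nn


def init_weights(seed: int) -> torch.Tensor:
    # Seed, then construct a layer so its random initialisation is reproducible.
    torch.manual_seed(seed)
    return nn.Linear(4, 4).weight.detach().clone()


same_seed_matches = torch.equal(init_weights(0), init_weights(0))
different_seed_differs = not torch.equal(init_weights(0), init_weights(1))
```

The same check fails if the seed is set after the model has been constructed, which is why the seeding call belongs at the very top of the script.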

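When `num_workers > 0`, each DataLoader worker is a separate process with its own RNG state. PyTorch seeds each worker with `base_seed + worker_id`, where `base_seed` is drawn from the main-process RNG or from the `generator` you pass in. To make worker randomness reproducible you should pass a `generator` and a `worker_init_fn`: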

```
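def seed_worker(worker_id: int) -> None:
    # Derive NumPy and Python seeds from this worker's PyTorch seed.
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

g = torch.Generator()
g.manual_seed(42)

loader = DataLoader(
    dataset,
    batch_size=64,
    shuffle=True,
    num_workers=2,
    generator=g,
    # a named function can be pickled for spawned workers; a lambda cannot
    worker_init_fn=seed_worker,
)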
```

`nn.Module` provides parameter tracking, device movement, train/eval mode switching, and serialisation. Each instance needs an `__init__` (to register all the submodules) and `forward` (to define the computation).


``` python
class MLP(nn.Module):
    def __init__(self, in_features, hidden, out_features):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_features, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, out_features),
        )

    def forward(self, x):
        return self.net(x)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model  = MLP(in_features=2, hidden=128, out_features=3).to(device)
```

`super().__init__()` is required: it initialises the module registry. Assigning submodules or parameter tensors as attributes in `__init__` registers them automatically. Unregistered plain attributes are excluded from parameter iteration, serialisation, and device movement.
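The registration rule is easiest to see by counting parameters. In this sketch the two classes (`Unregistered` and `Registered`, both invented for the example) hold identical layers, but only one of them stores the layers in a container that `nn.Module` tracks.

```python
import torch.nn as nn


class Unregistered(nn.Module):
    def __init__(self):
        super().__init__()
        # A plain Python list hides the layers from the module registry.
        self.layers = [nn.Linear(4, 4), nn.Linear(4, 4)]


class Registered(nn.Module):
    def __init__(self):
        super().__init__()
        # nn.ModuleList registers each layer as a submodule.
        self.layers = nn.ModuleList([nn.Linear(4, 4), nn.Linear(4, 4)])


n_unregistered = sum(p.numel() for p in Unregistered().parameters())
n_registered = sum(p.numel() for p in Registered().parameters())
```

An optimiser built from `Unregistered().parameters()` would have nothing to update, and `.to(device)` would leave those layers where they were.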

**`.to(device)`** on a module moves all registered parameters and buffers to the target device in-place. Call it before constructing the optimiser. For a device-only move, `nn.Module.to()` modifies each parameter's `.data` attribute in-place and the optimiser's references remain correct. When a dtype conversion is combined (e.g. `.half().to(device)`), `nn.Module.to()` can allocate new `nn.Parameter` objects and replace them in the module's internal registry; an optimiser constructed before this retains references to the originals and applies updates to them instead of the converted parameters.

**`register_buffer`** is for tensors that should follow the module (move with `.to(device)`, appear in `state_dict`) but are not trained parameters. BatchNorm's `running_mean` and `running_var` are buffers, as is a fixed attention mask in many transformer implementations.

**`model.parameters()`** returns an iterator over every registered parameter, including any frozen with `requires_grad=False`. `model.named_parameters()` yields the same parameters paired with their names.

Calling `model(x)` invokes `__call__`, not `forward` directly, which is a common point of confusion for new users. The `__call__` wrapper runs any forward pre-hooks, then `forward`, then any forward hooks, and wires up registered backward hooks. The distinction matters when using hooks for instrumentation or gradient modification.
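To make the `__call__` versus `forward` distinction concrete, the sketch below registers a forward hook and then invokes the layer both ways. The hook function `record_shape` and the `captured` list are names made up for this example.

```python
import torch
import torch.nn as nn

layer = nn.Linear(3, 2)
captured = []


def record_shape(module, inputs, output):
    # Forward hooks receive the module, its inputs, and its output.
    captured.append(tuple(output.shape))


handle = layer.register_forward_hook(record_shape)
x = torch.randn(5, 3)

_ = layer(x)          # goes through __call__, so the hook fires
_ = layer.forward(x)  # bypasses __call__, so the hook does not fire

handle.remove()       # detach the hook once it is no longer needed
```

Only one shape ends up in `captured`, which is the reason to always call the module rather than its `forward` method.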

**`torch.compile(model)`** (PyTorch 2.0+ only) traces the forward pass and emits optimised Triton/CUDA kernels via the backend. It fuses adjacent element-wise operations, reducing memory traffic. The first forward pass is slow because it triggers compilation, but subsequent ones are typically 10–30% faster on GPU, sometimes more on inference-heavy workloads.

**The PyTorch Parameter**

There are two things we care about in the `nn.Parameter` class: the values stored in `.data` and the `.grad` attribute. The backward pass accumulates gradients additively into `.grad`; the optimiser reads `.grad` to compute the update and writes the result back to `.data`.

You can also use parameters without gradients by setting `requires_grad=False`. This is useful for freezing layers, for inference-only models, or for non-trainable parameters.
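A minimal sketch of freezing, assuming a toy two-layer network (`net` is illustrative, not the `MLP` above): the first layer is frozen, so backward leaves its `.grad` untouched while the second layer still receives gradients.

```python
import torch
import torch.nn as nn

net = nn.Sequential(nn.Linear(2, 4), nn.ReLU(), nn.Linear(4, 1))

# Freeze the first layer: autograd will not compute gradients for it.
for p in net[0].parameters():
    p.requires_grad_(False)

loss = net(torch.randn(8, 2)).pow(2).mean()
loss.backward()

frozen_has_grad = net[0].weight.grad is not None
head_has_grad = net[2].weight.grad is not None

# Frozen parameters are still returned by parameters(); filter before
# handing them to an optimiser if you want to skip them explicitly.
all_params = list(net.parameters())
trainable = [p for p in all_params if p.requires_grad]
```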

`model.train()` puts the model in training mode: `Dropout` applies random masking and rescales, and `BatchNorm` computes statistics from the current batch rather than its stored running estimates.

```
model.train()                          # [1]
for X_batch, y_batch in loader:
    ...
```

`model.train()` sets a flag on the module and all sub-modules. Two common layers read it during their forward pass.

**Dropout** samples a Bernoulli mask with probability `p` of zeroing each element, then scales surviving activations by $1/(1-p)$ to preserve expected magnitude. In eval mode it acts as an identity function. Rescaling at train time (inverted dropout) means you don't need to rescale at inference.

**BatchNorm** during training computes the mean and variance of the current batch, normalises against them, and updates its stored `running_mean` and `running_var` via exponential moving average. In eval mode it uses those stored estimates instead of the batch statistics.

`LayerNorm`, `GELU`, `ReLU`, and most other layers are unaffected by the training flag and behave identically in both modes. The flag usually only matters if your architecture contains `Dropout` or any `BatchNorm` variant.

Omitting `model.train()` after validation does not raise an exception. The model trains with dropout disabled and BatchNorm using frozen statistics, both of which alter the effective learning dynamics.
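The effect of the training flag on `Dropout` can be checked directly. This is a small sketch on a tensor of ones, so every surviving element shows the $1/(1-p)$ rescaling.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
drop = nn.Dropout(p=0.5)
x = torch.ones(1000)

drop.train()
train_out = drop(x)   # each element is either 0 or 1 / (1 - p) = 2

drop.eval()
eval_out = drop(x)    # identity: the input passes through unchanged

train_values = set(train_out.unique().tolist())
```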

During validation, the larger memory reduction comes from `torch.no_grad()` rather than `model.eval()`. It disables graph construction entirely: no `grad_fn` is attached to output tensors and no intermediate activations are stored for backward. Used together during validation, they roughly halve memory usage compared to a training-mode forward pass.

`optimiser.zero_grad()` clears `.grad` before the next forward pass. PyTorch accumulates gradients additively. Without this call, gradients from the previous batch add to the current one.

```
for X_batch, y_batch in loader:
    optimiser.zero_grad()              # [2]
    ...
```

`optimiser.zero_grad()` resets every parameter's `.grad` attribute before the next forward pass. PyTorch accumulates gradients additively: each `.backward()` call adds to `.grad`; it does not replace it. This behaviour enables gradient accumulation: summing gradients over multiple micro-batches before stepping is equivalent to training with a proportionally larger batch.

```
# gradient accumulation: effective batch size = batch_size × ACCUM_STEPS
ACCUM_STEPS = 4
for i, (X_batch, y_batch) in enumerate(loader):
    logits = model(X_batch)
    loss   = criterion(logits, y_batch) / ACCUM_STEPS   # scale to match mean reduction
    loss.backward()                                       # adds to .grad each time
    if (i + 1) % ACCUM_STEPS == 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimiser.step()
        optimiser.zero_grad()
```

Accumulating gradients across multiple small batches and stepping once is mathematically equivalent to training with a batch four times larger, assuming the loss is mean-reduced. At scale, where a full batch does not fit in GPU memory, accumulation is the standard approach.
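The additive behaviour that accumulation relies on is easy to verify on a single tensor. This sketch uses a bare tensor `w` rather than a model, and resets `.grad` by hand, which is what `zero_grad(set_to_none=True)` does for each parameter.

```python
import torch

w = torch.tensor([1.0, 2.0], requires_grad=True)

(w * 3).sum().backward()
first = w.grad.clone()      # gradient of 3*w is 3 for each element

(w * 3).sum().backward()    # no zeroing in between
second = w.grad.clone()     # the new gradient was added to the old one

w.grad = None               # reset, as zero_grad(set_to_none=True) would
(w * 3).sum().backward()
third = w.grad.clone()      # back to a single batch's gradient
```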

**`zero_grad(set_to_none=True)`** sets `.grad` to `None` rather than filling it with zeros. It is slightly faster (skips the zero-fill pass) and uses less memory (the gradient tensors are freed). This is the default in PyTorch 2.0 onwards. Be careful with custom code that reads `.grad` directly: it has to handle `None`.

`model(X_batch)` runs `forward()` and builds the autograd computation graph. Every operation on a tensor with `requires_grad=True` is recorded: what the inputs were, what the output was, and how to compute the local vector-Jacobian product for backward.

```
logits = model(X_batch)                # [3]
```

During the forward pass, PyTorch builds a computation graph that records every operation on tensors with `requires_grad=True`. Each node in the graph represents an operation, storing the input tensors, the output tensor, and the function needed to compute gradients during the backward pass. This allows PyTorch to automatically compute gradients for all parameters with respect to a scalar loss.

Crucially, this is a dynamic graph: it's constructed on-the-fly during the forward pass. This means that the graph can change from one forward pass to the next, allowing for more flexible model architectures, such as those with conditional branches or loops, and variable input sizes.

This graph is a directed acyclic graph (DAG) from the loss back to the leaf parameters. During the forward pass, activations from every layer are kept alive in memory for use in the backward pass. For a batch of size $B$ and a network with $L$ layers, the memory cost scales with $O(B \cdot L)$ in the naive case.

**Gradient checkpointing** (`torch.utils.checkpoint.checkpoint`) reduces this. Instead of storing all intermediate activations, it re-runs the forward pass of checkpointed segments during backward to recompute them on demand. Peak activation memory scales with the number of uncheckpointed segments rather than total depth, at the cost of roughly one additional forward pass per checkpointed segment.

``` python
from torch.utils.checkpoint import checkpoint

def forward(self, x):
    x = checkpoint(self.block1, x, use_reentrant=False)
    x = checkpoint(self.block2, x, use_reentrant=False)
    return self.head(x)
```

Inside `torch.no_grad()`, no graph is constructed and no activations are stored. Memory usage drops roughly by half compared to a training-mode forward pass.

The loss function reduces the logits and targets to a single number. For `CrossEntropyLoss`, this is log-softmax followed by negative log-likelihood, averaged over the batch. The result is a scalar tensor with a `grad_fn`, still connected to the graph.

```
loss = criterion(logits, y_batch)      # [4]
```

The Cross-Entropy loss is a standard choice for multi-class classification problems. It measures the difference between the predicted probability distribution (from the model's logits) and the true distribution (the one-hot encoded labels).

`nn.CrossEntropyLoss` is a composition of two operations: `LogSoftmax` followed by `NLLLoss`. The combined form is more numerically stable than computing them separately, because it avoids materialising the softmax probabilities and then taking their log.

The underlying computation uses the log-sum-exp trick to prevent overflow. For logits $z_1, \dots, z_K$ with $m = \max_j z_j$:

$$\log \sum_{j} e^{z_j} = m + \log \sum_{j} e^{z_j - m}$$

Subtracting $m$ before exponentiation keeps the values in a safe range. The loss for a single example with true class $y$ is then:

$$\ell = -z_y + \log \sum_{j} e^{z_j}$$

and the batch loss is the mean over examples.
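The log-sum-exp trick fits in a few lines of plain Python. This is a sketch of the arithmetic for one example, not PyTorch's implementation; `cross_entropy` here is a local helper, and the logits are arbitrary values chosen to show the overflow case.

```python
import math


def cross_entropy(logits, target):
    """Cross-entropy for one example, computed via the log-sum-exp trick."""
    m = max(logits)
    # Every exponent is <= 0 after subtracting the max, so exp() cannot overflow.
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]


small = cross_entropy([2.0, 1.0, 0.1], target=0)

# math.exp(1000.0) would raise OverflowError; the shifted form stays finite.
large = cross_entropy([1000.0, 999.0, 0.0], target=0)
```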

**Common arguments:**

`weight` accepts a 1D tensor of per-class weights, applied to each sample's loss contribution. Use this for class imbalance, i.e. to upweight rare classes.

`label_smoothing=0.1` distributes a fraction of the probability mass uniformly across all classes rather than concentrating it on the target. With smoothing $\varepsilon$ and $K$ classes, the effective target distribution becomes $(1-\varepsilon)\,\mathbb{1}[k = y] + \varepsilon / K$. It discourages overconfident pseudo-probabilities and is standard in modern image and language model training.

`ignore_index=-100` masks positions with that label from the loss. Used in sequence modeling to exclude padding and masked tokens.

`reduction='mean'` divides by batch size. `'sum'` does not. Switching between them shifts the effective loss scale and therefore the effective learning rate.
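As a sanity check on the label-smoothing target, the sketch below builds the smoothed distribution by hand and compares the resulting loss with the built-in argument. The logits and targets are random placeholders.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(5, 3)
targets = torch.tensor([0, 2, 1, 1, 0])
eps, K = 0.1, 3

# Smoothed target: (1 - eps) on the true class plus eps / K on every class.
smoothed = F.one_hot(targets, K).float() * (1 - eps) + eps / K

# Cross-entropy against a full distribution: -sum_k q_k * log p_k, batch mean.
manual = -(smoothed * F.log_softmax(logits, dim=1)).sum(dim=1).mean()
builtin = F.cross_entropy(logits, targets, label_smoothing=eps)
```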

Logging with `loss.item()` extracts a Python float that carries no reference to the graph. Logging the tensor directly pins the graph in memory for the duration of the logging call, and storing or accumulating loss tensors across iterations (for example `running_loss += loss`) keeps every one of those graphs alive, which can lead to runaway memory growth during training.

`loss.backward()` computes `.grad` on each parameter via the chain rule. It does not modify weights.

```
loss.backward()                        # [5]
```

`backward()` implements reverse-mode automatic differentiation (backpropagation). For a scalar output, a single backward pass computes gradients with respect to all parameters. Forward-mode differentiation requires separate passes, one per parameter. At the parameter counts typical of modern networks, reverse mode is the only practical option.

The backward pass walks the computation graph from the loss in reverse topological order. At each node, it applies the backward function (the vector-Jacobian product for that operation) and propagates the result to the node's inputs. At leaf parameters, the result accumulates into `.grad`.

As we said before, the `.grad` attribute accumulates additively. If `.backward()` is called without zeroing gradients first, the new gradients add to whatever was already in `.grad`. Gradient accumulation relies on this behaviour, and we use `zero_grad()` to reset it.

**`retain_graph=True`** prevents the graph from being deallocated after backward. Normally the graph is freed immediately because it is assumed to be used once. You need `retain_graph=True` when calling backward multiple times through the same graph (for example, in a GAN where you backpropagate through a shared encoder for both the discriminator and generator losses).

**`create_graph=True`** allows differentiation through the backward pass itself. The backward computation becomes differentiable, enabling higher-order derivatives: Hessian-vector products, MAML-style meta-gradients, or second-order optimisers.
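A small worked case of `create_graph=True`, using `torch.autograd.grad` on $y = x^3$ so the derivatives can be checked by hand: $dy/dx = 3x^2$ and $d^2y/dx^2 = 6x$, both equal to 12 at $x = 2$.

```python
import torch

x = torch.tensor(2.0, requires_grad=True)
y = x ** 3

# First derivative, keeping the backward computation itself differentiable.
(dy_dx,) = torch.autograd.grad(y, x, create_graph=True)

# Second derivative: differentiate the first derivative with respect to x.
(d2y_dx2,) = torch.autograd.grad(dy_dx, x)
```

Without `create_graph=True` the second call would fail, because `dy_dx` would carry no graph to differentiate through.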

If the global gradient norm exceeds `max_norm`, `clip_grad_norm_` rescales all gradients proportionally. Called after `backward()`, before `step()`.

```
torch.nn.utils.clip_grad_norm_(        # [6]
    model.parameters(), max_norm=1.0
)
```

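`clip_grad_norm_` computes the global norm across all parameters, treating every gradient $g_i$ as part of one long vector:

$$\lVert g \rVert = \sqrt{\sum_i \lVert g_i \rVert_2^2}$$

If $\lVert g \rVert > \text{max\_norm}$, every gradient tensor is multiplied by $\text{max\_norm} / \lVert g \rVert$. Relative direction is preserved; only the magnitude is bounded.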
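The global norm can be computed by hand and compared with what `clip_grad_norm_` reports. This sketch uses two standalone tensors in place of model parameters and a deliberately large loss so that clipping actually triggers.

```python
import torch

torch.manual_seed(0)
params = [torch.randn(3, 3, requires_grad=True), torch.randn(5, requires_grad=True)]
loss = sum((p * 10).pow(2).sum() for p in params)
loss.backward()

max_norm = 1.0

# Global norm: the L2 norm of all gradients treated as one long vector.
total_norm = torch.sqrt(sum(p.grad.pow(2).sum() for p in params))

# clip_grad_norm_ rescales .grad in place and returns the pre-clipping norm.
returned = torch.nn.utils.clip_grad_norm_(params, max_norm)

clipped_norm = torch.sqrt(sum(p.grad.pow(2).sum() for p in params))
```

Logging the returned value each step is a cheap way to spot gradient spikes before they destabilise training.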

Gradient spikes are common in transformer training. Attention softmax can saturate, producing near-one-hot distributions, which propagates large gradient norms to earlier parameters via the chain rule.

Global norm clipping is standard practice: `max_norm=1.0` is the value reported in the GPT-3 paper and is used in much subsequent language model work and in many vision transformer papers. It is less commonly necessary for small MLPs on well-conditioned data.

`clip_grad_value_` clips individual gradient components to $[-\text{clip\_value}, \text{clip\_value}]$ rather than the global norm. It does not preserve gradient direction and is less commonly used.

Placing clipping before `backward()` has no effect (`.grad` is empty). Placing it after `step()` clips gradients that have already been applied to weights.

`optimiser.step()` reads `.grad` on every parameter and applies the update rule. Moment estimates are updated. Weights change here and only here.

```
optimiser.step()                       # [7]
```

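Parameter values change in `step()` and nowhere else. For plain SGD with momentum $\mu$, learning rate $\eta$ and gradient $g_t$:

$$u_t = \mu\, u_{t-1} + g_t, \qquad \theta_t = \theta_{t-1} - \eta\, u_t$$

Adam maintains per-parameter estimates of the first moment (gradient mean) and second moment (uncentred gradient variance):

$$m_t = \beta_1 m_{t-1} + (1 - \beta_1)\, g_t, \qquad v_t = \beta_2 v_{t-1} + (1 - \beta_2)\, g_t^2$$

Both estimates are biased toward zero at initialisation (they start at zero, and early values of $m_t$ and $v_t$ underestimate the true moments). Bias correction accounts for this:

$$\hat{m}_t = \frac{m_t}{1 - \beta_1^t}, \qquad \hat{v}_t = \frac{v_t}{1 - \beta_2^t}$$

The weight update is then:

$$\theta_t = \theta_{t-1} - \eta\, \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}$$

Default hyperparameters: $\beta_1 = 0.9$, $\beta_2 = 0.999$, $\epsilon = 10^{-8}$. The per-parameter adaptive learning rate makes Adam robust to varying gradient scales across parameters, which matters in deep networks where early and late layers often have gradients of very different magnitudes.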
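The update rule is short enough to write out for a single scalar parameter. This is a plain-Python sketch of the arithmetic, not `torch.optim.Adam` itself; `adam_step` is a helper made up for the example. It shows the scale-invariance claim: on the first step the update is close to `lr` whether the gradient is tiny or huge.

```python
import math


def adam_step(theta, grad, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update for a scalar parameter at step t (counting from 1)."""
    m = beta1 * m + (1 - beta1) * grad          # first-moment estimate
    v = beta2 * v + (1 - beta2) * grad ** 2     # second-moment estimate
    m_hat = m / (1 - beta1 ** t)                # bias correction
    v_hat = v / (1 - beta2 ** t)
    theta = theta - lr * m_hat / (math.sqrt(v_hat) + eps)
    return theta, m, v


# Gradients eight orders of magnitude apart produce almost the same step.
small_step, _, _ = adam_step(theta=0.0, grad=1e-4, m=0.0, v=0.0, t=1)
large_step, _, _ = adam_step(theta=0.0, grad=1e+4, m=0.0, v=0.0, t=1)
```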

**AdamW** decouples weight decay from the gradient update. Standard Adam with L2 regularisation adds a gradient term before the update, which then gets scaled by the adaptive step size. AdamW applies weight decay directly to the weights, with decay coefficient $\lambda$: $\theta \leftarrow \theta - \eta \lambda \theta$. The two are not equivalent; AdamW gives more predictable effective regularisation and is now standard for transformer training.

**`fused=True`** (CUDA only) fuses the entire parameter update into a single CUDA kernel per parameter group, avoiding the Python loop and multiple kernel launches. It runs approximately 30–50% faster than the default at large parameter counts.

**`foreach=True`** uses batched `torch._foreach_*` operations that process all parameters in a group together rather than one at a time in a Python loop. It is intermediate between the default and fused in speed, and available on CPU.

The optimiser stores moment state for every parameter. For Adam on a 7B-parameter model, that is 7B float32 values for $m$ and 7B for $v$, roughly 56GB of optimiser state at full precision. At that scale, training typically uses 8-bit optimisers (bitsandbytes) or shards state across devices (ZeRO).
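The 56GB figure is just two buffers times four bytes per value. A tiny helper (hypothetical, for the arithmetic only) makes it easy to redo the estimate for other model sizes or state precisions:

```python
def adam_state_gb(n_params: float, bytes_per_value: int = 4) -> float:
    # Adam keeps two moment buffers (m and v), each the size of the parameters.
    return 2 * n_params * bytes_per_value / 1e9


state_7b_fp32 = adam_state_gb(7e9)                      # float32 state
state_7b_8bit = adam_state_gb(7e9, bytes_per_value=1)   # 8-bit optimiser state
```

This counts optimiser state only; the weights and gradients need their own memory on top.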

`scheduler.step()` updates the learning rate. It is called after `optimiser.step()`, outside the batch loop.

```
scheduler.step()                       # [8] once per epoch, outside the batch loop
```

The scheduler modifies the `lr` field of each parameter group in the optimiser. The most common mistake is calling it inside the batch loop:

```
# wrong: lr decays len(loader) times per epoch instead of once
for X_batch, y_batch in loader:
    optimiser.step()
    scheduler.step()

# correct
for X_batch, y_batch in loader:
    optimiser.step()
scheduler.step()
```

**Cosine annealing** decays the learning rate from `eta_max` to `eta_min` over `T_max` epochs:

$$\eta_t = \eta_{\min} + \tfrac{1}{2}\left(\eta_{\max} - \eta_{\min}\right)\left(1 + \cos\left(\frac{\pi t}{T_{\max}}\right)\right)$$

In modern large model training, the standard schedule combines a linear warmup with cosine decay. A warmup phase (typically 1–5% of total steps) limits the update magnitude while parameters are far from their trained values. Cosine decay reduces the rate for the remaining steps.

```
# linear warmup + cosine decay
# (using HuggingFace's implementation as reference)
from transformers import get_cosine_schedule_with_warmup

scheduler = get_cosine_schedule_with_warmup(
    optimiser,
    num_warmup_steps=100,
    num_training_steps=10_000,
)
scheduler.step()   # called per step, not per epoch, with this scheduler
```
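The shape of a warmup-plus-cosine schedule can also be written out directly, which helps when checking what learning rate a given step should have. The function below is a plain-Python sketch under assumed defaults (`base_lr`, `min_lr` and the step counts are illustrative); it is not the HuggingFace implementation.

```python
import math


def lr_at(step, base_lr=1e-3, warmup_steps=100, total_steps=10_000, min_lr=0.0):
    """Learning rate at a given step: linear warmup, then cosine decay."""
    if step < warmup_steps:
        # Ramp linearly from 0 up to base_lr.
        return base_lr * step / warmup_steps
    # Fraction of the decay phase completed, from 0 to 1.
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * progress))


lr_start = lr_at(0)            # start of warmup
lr_mid_warmup = lr_at(50)      # halfway through warmup
lr_peak = lr_at(100)           # warmup finished
lr_mid_decay = lr_at(5_050)    # halfway through the cosine phase
lr_end = lr_at(10_000)         # end of training
```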

`ReduceLROnPlateau` is the exception to the no-argument rule. It takes a metric value and reduces the learning rate only when the metric has stopped improving for `patience` epochs. It must be called with the validation loss: `scheduler.step(val_loss)`.

`model.eval()` changes layer behaviour. `torch.no_grad()` stops graph construction. They are independent operations; you need both for validation.

```
model.eval()                           # [9]
with torch.no_grad():                  # [10]
    val_logits = model(X_val)
    val_loss   = criterion(val_logits, y_val)
```

`model.eval()` sets `self.training = False` on every module. **BatchNorm** switches from computing batch statistics to using its stored `running_mean` and `running_var`. **Dropout** switches from sampling Bernoulli masks to the identity function. No other standard layers are affected.

`torch.no_grad()` disables the construction of the autograd graph entirely. No `grad_fn` is attached to tensors produced inside the context; no intermediate activations are saved. This roughly halves memory usage compared to a training-mode forward pass and makes computation slightly faster (fewer bookkeeping operations per op).

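The two are independent choices: `model.eval()` on its own still builds the graph, and `torch.no_grad()` on its own still leaves Dropout and BatchNorm in training mode.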
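A short check of that independence, using a toy network with a `Dropout` layer (the names `net` and `flags` are only for this example). Each setting is applied without the other, and the result is inspected.

```python
import torch
import torch.nn as nn

net = nn.Sequential(nn.Linear(4, 4), nn.Dropout(0.5))
x = torch.randn(2, 4)

net.eval()
out_eval_only = net(x)            # eval mode, but autograd is still recording

net.train()
with torch.no_grad():
    out_no_grad_only = net(x)     # no graph, but the module is still in train mode

flags = {
    "eval_only_has_graph": out_eval_only.grad_fn is not None,
    "no_grad_only_has_graph": out_no_grad_only.grad_fn is not None,
    "no_grad_only_training": net.training,
}
```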

**`torch.inference_mode()`** is a stricter form of `no_grad()`. Tensors created inside the context have `is_inference() == True`, which prevents them from participating in a backward pass even if they escape the context manager. It removes a few additional checks in the autograd engine and runs about 10% faster than `no_grad()`.

Track training and validation metrics using `loss.item()`, not `loss`. Calling `.item()` extracts a Python float with no link to the graph; holding a reference to the tensor keeps the full backward graph alive until that reference is dropped.

```
for epoch in range(NUM_EPOCHS):
    model.train()
    running_loss = 0.0
    for X_batch, y_batch in loader:
        optimiser.zero_grad()
        loss = criterion(model(X_batch), y_batch)
        loss.backward()
        optimiser.step()
        running_loss += loss.item()
    train_loss = running_loss / len(loader)

    model.eval()
    with torch.no_grad():
        val_logits = model(X_val)
        val_loss   = criterion(val_logits, y_val).item()
        val_acc    = (val_logits.argmax(1) == y_val).float().mean().item()

    print(f'epoch {epoch:3d}  train {train_loss:.4f}  val {val_loss:.4f}  acc {val_acc:.3f}')
```

`torch.save` writes a checkpoint file; `torch.load` reads it back. Checkpoints allow a training run to survive crashes and preemption.

```
# save
torch.save({
    'epoch':                epoch,
    'model_state_dict':     model.state_dict(),
    'optimiser_state_dict': optimiser.state_dict(),
    'scheduler_state_dict': scheduler.state_dict(),
    'val_loss':             val_loss,
}, 'checkpoint.pt')

# resume
checkpoint = torch.load('checkpoint.pt', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
optimiser.load_state_dict(checkpoint['optimiser_state_dict'])
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
start_epoch = checkpoint['epoch'] + 1
```

`state_dict()` returns an ordered dictionary of parameter and buffer tensors. It does not include the class definition, so the model class must be defined in scope before calling `load_state_dict`.

The optimiser state must be saved alongside the model to resume training correctly. For Adam, the state includes the per-parameter moment estimates $m$ and $v$, which take several hundred steps to warm up from zero. Resuming without them is equivalent to restarting the optimiser cold, which typically produces a loss spike at the resume point.

`map_location=device` handles the common case where the checkpoint was saved on a different GPU than the one loading it.

The standard pattern saves only when validation loss improves, so the saved weights correspond to the best-generalising epoch rather than the final one:

```
best_val_loss = float('inf')

for epoch in range(NUM_EPOCHS):
    # ... training loop ...

    model.eval()
    with torch.no_grad():
        val_loss = criterion(model(X_val), y_val).item()

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), 'best_model.pt')
```

To restore the best model after training: `model.load_state_dict(torch.load('best_model.pt', map_location=device))`.

While we're working through the nitty-gritty of the training loop, it's worth spending a few minutes thinking about some of the basic GPU optimisation techniques. None of these exactly count as "melting the hardware" but they're almost all entirely free speed.

Put the model and data on the *same* GPU. The model moves once; each batch moves as it is consumed, which keeps data transfer overhead to a minimum.

```
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = MLP(in_features=2, hidden=128, out_features=3).to(device)
# construct optimiser AFTER moving model: it captures parameter references
optimiser = torch.optim.Adam(model.parameters(), lr=1e-3)

# per-batch: move data to device
for X_batch, y_batch in loader:
    X_batch = X_batch.to(device, non_blocking=True)
    y_batch = y_batch.to(device, non_blocking=True)
```

`non_blocking=True` initiates the host-to-device transfer asynchronously and returns immediately. The GPU can continue work on a previous batch while the new batch is still transferring. The overlap only helps when `pin_memory=True` in the DataLoader; unpinned memory cannot be transferred asynchronously.

Modern GPUs have dedicated tensor cores that execute float16 (and bfloat16) matrix multiplications 4–8× faster than float32 on the same hardware. Mixed precision training keeps weights in float32 but runs the forward and backward passes in float16.

This makes the computation more efficient, but float16 has a much smaller dynamic range than float32, so small gradient values underflow to zero, producing incorrect updates. The compromise is to keep the master copy of the weights in float32, so that small updates are not lost to float16 rounding as they accumulate.

```
scaler = torch.amp.GradScaler('cuda')

for X_batch, y_batch in loader:
    optimiser.zero_grad()

    with torch.amp.autocast('cuda', dtype=torch.float16):
        logits = model(X_batch)
        loss   = criterion(logits, y_batch)

    scaler.scale(loss).backward()           # backward in float16
    scaler.unscale_(optimiser)              # unscale before clipping
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    scaler.step(optimiser)                  # step only if no inf/nan
    scaler.update()                         # adjust scale factor
```

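`torch.amp.autocast` and `torch.amp.GradScaler` are the current, device-agnostic API in recent PyTorch 2.x releases (`torch.amp.GradScaler` is not available in the earliest 2.x versions). The older `from torch.cuda.amp import autocast, GradScaler` still works but is deprecated.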

**Why GradScaler:** float16 has a much smaller dynamic range than float32, so small gradient values underflow to zero and produce incorrect updates. GradScaler multiplies the loss by a large scale factor (the default initial value is 2^16 = 65536) before backward, which shifts gradient magnitudes up into the representable float16 range. Before the optimiser step, the gradients are divided by the same factor to restore their true scale. If an overflow is detected (inf or nan in any gradient), the step is skipped and the scale factor is reduced.
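The underflow and the effect of scaling can be reproduced with plain tensors. This is a sketch of the arithmetic only, not of GradScaler's internals; the value `1e-8` is a hypothetical gradient magnitude chosen to sit below float16's smallest representable number, and `2 ** 16` matches GradScaler's default initial scale.

```python
import torch

true_grad = 1e-8          # hypothetical gradient value, too small for float16
scale = 2.0 ** 16         # GradScaler's default initial scale (65536)

# Stored directly in float16, the value rounds to zero
unscaled = torch.tensor(true_grad, dtype=torch.float16)

# Scaled first, it lands comfortably inside the float16 range
scaled = torch.tensor(true_grad * scale, dtype=torch.float16)

# Dividing by the scale in float32 recovers the original magnitude
recovered = scaled.float() / scale

print(unscaled.item())    # 0.0
print(recovered.item())   # close to 1e-8
```

The recovered value is not bit-exact, because float16 only carries about three decimal digits, but it is no longer zero, which is what matters for the update.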

**bfloat16** (`dtype=torch.bfloat16`) has the same exponent range as float32 but fewer mantissa bits, so gradients do not underflow the way they do in float16 and GradScaler is not needed. It is the preferred dtype on recent hardware (A100, H100, TPUs).

```
with torch.amp.autocast('cuda', dtype=torch.bfloat16):
    logits = model(X_batch)
    loss   = criterion(logits, y_batch)

loss.backward()                # no scaler needed
optimiser.step()
```
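The range-versus-precision trade-off between the three dtypes can be read straight from `torch.finfo`. The snippet below is a small inspection sketch; the `limits` dictionary is just a convenience introduced here.

```python
import torch

# Collect the numeric limits of each dtype
limits = {
    dtype: torch.finfo(dtype)
    for dtype in (torch.float32, torch.float16, torch.bfloat16)
}

for dtype, info in limits.items():
    # tiny = smallest positive normal value, eps = spacing between 1.0 and the next value
    print(f"{str(dtype):16s} tiny={info.tiny:.3e}  max={info.max:.3e}  eps={info.eps:.3e}")
```

bfloat16 reports the same smallest normal value as float32 and a similar maximum, while its `eps` is larger than float16's: it trades precision for range, which is why the scaler can be dropped.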

For GPU training, the bottleneck is often not the GPU but the data pipeline feeding it. A DataLoader configured with `num_workers=4, pin_memory=True, persistent_workers=True, prefetch_factor=2` keeps the GPU fed without stalling. `prefetch_factor=2` (the default when workers are enabled) means each worker loads two batches ahead of what has been consumed.
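One practical wrinkle is that `persistent_workers` and `prefetch_factor` are only accepted when `num_workers > 0`, so a configuration that also needs to run single-process (for debugging, say) has to add them conditionally. The helper below, `make_loader`, is a hypothetical name used for this sketch, and the synthetic dataset is an assumption.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

def make_loader(dataset, batch_size=64, num_workers=4):
    """Hypothetical helper: build a DataLoader with the settings discussed above."""
    kwargs = dict(
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=torch.cuda.is_available(),   # pinning only helps GPU transfers
    )
    if num_workers > 0:
        # Both options raise ValueError when num_workers == 0
        kwargs.update(persistent_workers=True, prefetch_factor=2)
    return DataLoader(dataset, **kwargs)

# Synthetic stand-in data (an assumption)
dataset = TensorDataset(torch.randn(256, 2), torch.randint(0, 3, (256,)))

# Single-process variant, handy for debugging; use the default num_workers=4 for training
loader = make_loader(dataset, num_workers=0)
```

With workers enabled, scripts on platforms that spawn rather than fork worker processes generally need the loader iteration to live under an `if __name__ == "__main__":` guard.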

`torch.backends.cudnn.benchmark = True` runs a short profiling pass on the first batch to determine the fastest convolution algorithm for your specific input shape. Subsequent batches use that cached choice. Do not use this if input shapes vary between batches; the profiling overhead is incurred again for each new shape.

```
model = torch.compile(model)
```

Available in PyTorch 2.0 and above, `torch.compile` traces the forward pass and emits optimised Triton kernels via the Inductor backend. The main benefit is operation fusion: adjacent element-wise operations (activation + dropout + residual add, for example) are merged into a single kernel, reducing memory reads and writes. Typical speedups are 10–30% on GPU training, higher for inference workloads and architectures with many small ops.

`mode='max-autotune'` runs additional profiling to choose the best kernel for each operation. Compilation takes longer, but the resulting model is faster, and the cost amortises over long training runs.

The first forward pass triggers compilation and is slow. `torch._dynamo.reset()` clears the cache if you need to recompile (e.g., after changing the model structure).

```
device = torch.device('cuda')
model  = torch.compile(MLP(...).to(device))

torch.backends.cudnn.benchmark = True

for epoch in range(NUM_EPOCHS):
    model.train()
    for X_batch, y_batch in loader:
        X_batch = X_batch.to(device, non_blocking=True)
        y_batch = y_batch.to(device, non_blocking=True)

        optimiser.zero_grad(set_to_none=True)

        with torch.amp.autocast('cuda', dtype=torch.bfloat16):
            logits = model(X_batch)
            loss   = criterion(logits, y_batch)

        loss.backward()                                        # no scaler: bf16 doesn't underflow
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimiser.step()

    scheduler.step()

    model.eval()
    with torch.inference_mode():
        val_logits = model(X_val.to(device))
        val_loss   = criterion(val_logits, y_val.to(device))
```
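The validation step at the end of the loop relies on two independent switches: `model.eval()` changes layer behaviour, while `torch.inference_mode()` turns off graph construction. The sketch below separates the two on a hypothetical toy model with a dropout layer; the model and input are assumptions for illustration.

```python
import torch
import torch.nn as nn

# Hypothetical toy model with dropout, so eval-mode behaviour is visible
model = nn.Sequential(nn.Linear(2, 8), nn.Dropout(p=0.5), nn.Linear(8, 3))
x = torch.randn(4, 2)

model.eval()                          # dropout becomes a pass-through
with torch.inference_mode():          # no autograd graph is recorded
    out_a = model(x)
    out_b = model(x)

deterministic = torch.equal(out_a, out_b)   # True: no random dropout masks in eval mode
tracks_grad = out_a.requires_grad           # False: output is detached from any graph
is_inference = out_a.is_inference()         # True: tensor was created under inference_mode

model.train()                         # switch back before the next training epoch
```

Forgetting the final `model.train()` is harmless in the loop above only because it is called at the top of every epoch; in loops that validate mid-epoch, the switch back has to be explicit.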

The same loop, with every line referenced.

- [1] Shape `(N, features)`. For the spiral dataset, N=2000, features=2.
- [2] Shape `(N,)`. Not one-hot: `CrossEntropyLoss` expects indices directly.
- [3] `TensorDataset` pairs inputs and labels by index. `__getitem__` returns `(X[i], y[i])`.
- [4] `DataLoader` handles batching, shuffling each epoch, and optional parallel prefetching via `num_workers`.
- [5] `nn.Module` subclass. Register submodules in `__init__`; define the computation in `forward`.
- [6] `.to(device)` moves all parameters and buffers. Do this before constructing the optimiser.
- [7] `CrossEntropyLoss` = LogSoftmax + NLLLoss. Pass raw logits, not softmax outputs.
- [8] `model.parameters()` supplies the tensors to optimise.
- [9] `scheduler.step()` adjusts `lr` in the optimiser's param groups.
- [10] `model.train()`: dropout masks active, batchnorm uses batch statistics.
- [11] `zero_grad()`: clears `.grad`. PyTorch accumulates by default; without this, gradients compound across batches.
- [13] `grad_fn` is the entry point for backward.
- [14] `backward()`: reverse-mode AD. Populates `.grad` on every leaf parameter. Weights unchanged.
- [16] `optimiser.step()`: reads `.grad`, applies Adam update, updates moment estimates. Weights change.
- [17] `scheduler.step()`: adjusts lr. Once per epoch, after `optimiser.step()`, outside the batch loop.
- [18] `model.eval()`: dropout pass-through, batchnorm uses running statistics.
- [19] `torch.no_grad()` (or `inference_mode()`): no graph construction. Faster, lower memory.

``` python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# make_labels, X_val and y_val are assumed to be defined elsewhere,
# with X_val and y_val already on `device`
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ── data ──────────────────────────────────────────────────────────────────────
X_train = torch.randn(2000, 2)                                     # [1]
y_train = make_labels(X_train)                                     # [2]

dataset = TensorDataset(X_train, y_train)                          # [3]
loader  = DataLoader(dataset, batch_size=64, shuffle=True,         # [4]
                     num_workers=2, pin_memory=True)

# ── model ──────────────────────────────────────────────────────────────────────
class MLP(nn.Module):                                              # [5]
    def __init__(self, in_features, hidden, out_features):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_features, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden),      nn.ReLU(),
            nn.Linear(hidden, out_features),
        )
    def forward(self, x):
        return self.net(x)

model = MLP(in_features=2, hidden=128, out_features=3).to(device)  # [6]

# ── loss, optimiser, scheduler ────────────────────────────────────────────────
criterion = nn.CrossEntropyLoss()                                  # [7]
optimiser = torch.optim.Adam(model.parameters(), lr=1e-3)          # [8]
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(            # [9]
    optimiser, T_max=100
)

# ── training loop ─────────────────────────────────────────────────────────────
for epoch in range(100):

    model.train()                                                   # [10]

    for X_batch, y_batch in loader:
        # batches come off the loader on the CPU; move them to the model's device
        X_batch = X_batch.to(device, non_blocking=True)
        y_batch = y_batch.to(device, non_blocking=True)

        optimiser.zero_grad()                                       # [11]

        logits = model(X_batch)                                     # [12]

        loss = criterion(logits, y_batch)                           # [13]

        loss.backward()                                             # [14]

        torch.nn.utils.clip_grad_norm_(                             # [15]
            model.parameters(), max_norm=1.0
        )

        optimiser.step()                                            # [16]

    scheduler.step()                                                # [17]

    # ── validation ────────────────────────────────────────────────────────────
    model.eval()                                                    # [18]
    with torch.no_grad():                                           # [19]
        val_logits = model(X_val)
        val_loss   = criterion(val_logits, y_val)
        val_acc    = (val_logits.argmax(1) == y_val).float().mean()
```
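Two of the annotations above can be checked in a few lines: that `CrossEntropyLoss` is LogSoftmax followed by NLLLoss on raw logits with integer class targets, and that `.grad` accumulates across `backward()` calls unless it is cleared. The tensors below are small hypothetical examples, not the essay's data.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# --- CrossEntropyLoss = LogSoftmax + NLLLoss ---
logits = torch.randn(5, 3)                      # raw scores, shape (N, classes)
targets = torch.tensor([0, 2, 1, 1, 0])         # class indices, shape (N,), not one-hot

ce = nn.CrossEntropyLoss()(logits, targets)
manual = F.nll_loss(F.log_softmax(logits, dim=1), targets)
same_loss = torch.allclose(ce, manual)          # True

# --- gradients accumulate unless cleared ---
w = torch.ones(3, requires_grad=True)

(w * 2).sum().backward()
first = w.grad.clone()                          # tensor([2., 2., 2.])

(w * 2).sum().backward()                        # no zero_grad in between
accumulated = w.grad.clone()                    # tensor([4., 4., 4.])

w.grad = None                                   # per-parameter effect of zero_grad(set_to_none=True)
(w * 2).sum().backward()
after_reset = w.grad.clone()                    # tensor([2., 2., 2.])
```

Passing softmax outputs into `CrossEntropyLoss` would apply the softmax twice, which still trains but with flattened gradients, so the mistake tends to show up as slow convergence rather than an error.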


