# Introduction #
Watching the loss decrease while training a machine learning model feels like progress, until the validation accuracy plateaus or the loss begins to spike and you're not sure what caused it. At that point, most people add more logging or start tuning hyperparameters, hoping something changes. What most analysts skip at this stage is actual visibility into what is happening inside the model during training. Visual debugging tools provide that visibility.
In this article, we cover three topics: what to visualize during training (gradients, losses, and embeddings), the tools that provide those visualizations (TensorBoard and its main alternatives), and the methods to capture model computations directly using hooks and breakpoints.
# Visualizing Gradients, Losses, and Embeddings #
// Loss Curves
When training a model, the loss curve is usually the first thing to check. When both the training loss and validation loss decline and remain close, it indicates that the training is progressing well. When validation loss starts rising while training loss keeps falling, the model is overfitting. When both curves plateau early, the model isn't learning, which typically indicates a problem with the data or learning rate.
Gradient flow matters as well. If the loss curves decrease smoothly but very slowly, vanishing gradients are one possible cause: the gradients may be too small by the time they reach the early layers.
The plot shown below simulates a typical overfitting pattern. Both losses decrease together for the first ten epochs, and then the validation loss starts increasing while the training loss keeps falling.
The red dotted line marks where the divergence begins: in a real run, that's the point to start investigating regularization or early stopping.
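The pattern can be reproduced with simulated numbers. The sketch below is an illustration, not the code behind the plot: it reuses the simulated loss formulas from the W&B snippet later in this article, and the helper name `first_divergence` and the `patience` value are assumptions introduced here.

```python
# Simulated losses (assumed formulas, not real training results)
epochs = list(range(20))
train_loss = [1 / (1 + 0.3 * e) for e in epochs]
val_loss = [t + max(0, 0.04 * (e - 10)) for e, t in zip(epochs, train_loss)]


def first_divergence(train, val, patience=3):
    """Return the last epoch before validation loss started rising.

    Divergence is declared once validation loss has risen for `patience`
    consecutive epochs while training loss kept falling. Returns None if
    that never happens.
    """
    rising = 0
    for e in range(1, len(val)):
        if val[e] > val[e - 1] and train[e] < train[e - 1]:
            rising += 1
            if rising == patience:
                return e - patience
        else:
            rising = 0
    return None


divergence_epoch = first_divergence(train_loss, val_loss)
print(f"Validation loss bottoms out at epoch {divergence_epoch}")
```

Applied to logged losses from a real run, the same check could serve as a simple early-stopping trigger.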
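To check gradient flow directly, the snippet below records the mean gradient magnitude at each layer's output during a single backward pass:

import torch
import torch.nn as nn
import matplotlib.pyplot as plt

# Five-layer network: three Linear layers with Tanh activations in between
model = nn.Sequential(nn.Linear(16, 16), nn.Tanh(),
                      nn.Linear(16, 16), nn.Tanh(),
                      nn.Linear(16, 1))

grad_magnitudes = {}

def grad_hook(name):
    # Build a hook that stores the mean |gradient| arriving at the layer's output
    def hook(module, grad_input, grad_output):
        grad_magnitudes[name] = grad_output[0].abs().mean().item()
    return hook

for i, layer in enumerate(model):
    # register_full_backward_hook replaces the deprecated register_backward_hook
    layer.register_full_backward_hook(grad_hook(f"Layer {i} ({layer.__class__.__name__})"))

output = model(torch.randn(32, 16))
output.mean().backward()

# Hooks fire from the output layer backward, so the dict lists the last layer first
for name, value in grad_magnitudes.items():
    print(f"{name}: {value:.6f}")

plt.bar(grad_magnitudes.keys(), grad_magnitudes.values())
plt.title("Mean Gradient Magnitude per Layer")
plt.ylabel("Mean |gradient|")
plt.xticks(rotation=15)
plt.tight_layout()
plt.show()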
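A representative run prints the raw gradient magnitudes (exact values vary between runs because the input and initial weights are random):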
Layer 4 (Linear): 0.031250
Layer 3 (Tanh): 0.004646
Layer 2 (Linear): 0.004241
Layer 1 (Tanh): 0.002126
Layer 0 (Linear): 0.001631
The chart reads right to left: Layer 4 represents the output layer, and Layer 0 is the first. The output layer gets a gradient of 0.031, but by the time it reaches Layer 0, that number has dropped to 0.0016, roughly 20 times smaller.
The red bar that appears on each of the first three layers indicates that gradients are already in the risk zone before they ever reach the start of the network. In a real training run on a deeper model, these initial layers would adjust their weights so slowly that they would hardly learn anything.
This is a practical example of the vanishing gradient problem: the early layers are silently undertraining, which can't be seen without this kind of plot.
// Gradient Visualization
Plotting gradient magnitudes layer by layer during training gives a direct view of whether gradients are reaching the early parts of the network with considerable values. In deep models, gradients may vanish as they move backward through layers. The gradient value histograms for each layer, recorded during training, can reveal this pattern and help us identify the issue early on.
PyTorch's register_backward_hook function (superseded in recent releases by register_full_backward_hook) allows us to obtain gradient tensors from any layer without modifying the training loop. We connect a hook to a module, and it fires during each backward pass, sending the gradient tensors to a specified callback.
The histogram below shows the complete distribution of gradient values for each layer after one backward pass. Each subplot represents a single layer, ordered from the initial layer to the final one.
The code for this can be found here.
What we're looking for in a healthy network is histograms across layers with roughly similar spreads.
If the early layers show a very narrow, spike-like distribution centered tightly on zero, that could be a red flag indicating vanishing gradients.
The gradients still exist, but they're so small they carry almost no learning information. This visualization can help us catch this pattern after the first few batches, rather than after a full training run.
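As a rough sketch of how such histograms could be collected (not the author's linked code), the snippet below reads the weight gradients from each parameter's `.grad` after one backward pass and bins them with `torch.histc`. The network mirrors the small Tanh model used earlier; the seed, the 20-bin choice, and the variable names are assumptions made for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # fixed seed so repeated runs are comparable

model = nn.Sequential(nn.Linear(16, 16), nn.Tanh(),
                      nn.Linear(16, 16), nn.Tanh(),
                      nn.Linear(16, 1))

# One forward and backward pass on random data (illustrative only)
loss = model(torch.randn(32, 16)).mean()
loss.backward()

grad_histograms = {}
for name, param in model.named_parameters():
    if not name.endswith("weight"):
        continue  # skip biases to keep one histogram per layer
    grads = param.grad.detach().flatten()
    # Symmetric range around zero so a narrow spike is easy to spot
    limit = grads.abs().max().item() or 1e-12
    grad_histograms[name] = torch.histc(grads, bins=20, min=-limit, max=limit)
    print(f"{name}: std={grads.std().item():.6f}, max|g|={limit:.6f}")
```

Each entry in `grad_histograms` holds the bin counts for one layer, which can be drawn as one bar-chart subplot per layer. A much smaller standard deviation in the early layers than in the later ones is the numeric counterpart of the spike-like histogram described above.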
// Embeddings
When a model maps inputs to a learned representation, visualizing that representation tells us whether the model is separating the data as we'd expect. The most common approach is to take the embeddings from a trained (or partially trained) model, reduce their dimensionality using t-SNE or UMAP, and plot them with class labels as colors.
If the classes are tight and well-separated, that means the model has learned useful separation. Overlapping classes mean the model hasn't separated the concepts yet. This step is useful for debugging models trained on text or images before adding the final classification layer.
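A minimal sketch of the reduction step, assuming scikit-learn's `TSNE` and synthetic data: the random clusters below stand in for embeddings that would normally come from a model's penultimate layer, and the class count, dimensions, and perplexity are illustrative choices.

```python
import numpy as np
from sklearn.manifold import TSNE

rng = np.random.default_rng(0)

# Synthetic stand-in for real embeddings: 3 classes, 30 points each, 32 dims
n_classes, n_per_class, dim = 3, 30, 32
centers = rng.normal(scale=5.0, size=(n_classes, dim))
embeddings = np.vstack([c + rng.normal(size=(n_per_class, dim)) for c in centers])
labels = np.repeat(np.arange(n_classes), n_per_class)

# Reduce to 2-D; perplexity must be smaller than the number of samples
points_2d = TSNE(n_components=2, perplexity=15, random_state=0).fit_transform(embeddings)

print(points_2d.shape)  # (90, 2)
```

The result can be drawn with `plt.scatter(points_2d[:, 0], points_2d[:, 1], c=labels)` so that each class gets its own color.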
# TensorBoard and Its Alternatives #
// TensorBoard
TensorBoard is the standard starting point. Originally built for TensorFlow, it works with PyTorch through torch.utils.tensorboard. Data is logged through a SummaryWriter object, and the results can be viewed in a browser tab. It handles scalars (loss, accuracy), histograms (weight and gradient distributions), images, and an embedding projector for visualizing high-dimensional representations.
The main limitation is that it is local. Sharing results with a team means setting up shared storage for the log files; the hosted TensorBoard.dev service, which supported only a subset of these features, has since been shut down.
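As a rough illustration of the logging side, the sketch below writes simulated losses and a stand-in gradient histogram with `SummaryWriter`. The temporary log directory, the tag names, and the random "gradient" values are assumptions for the example; it also assumes the `tensorboard` package is installed alongside PyTorch.

```python
import tempfile

import torch
from torch.utils.tensorboard import SummaryWriter

log_dir = tempfile.mkdtemp(prefix="tb_demo_")  # hypothetical log location
writer = SummaryWriter(log_dir=log_dir)

for epoch in range(20):
    train_loss = 1 / (1 + 0.3 * epoch)  # simulated
    val_loss = train_loss + max(0, 0.04 * (epoch - 10))  # simulated
    writer.add_scalar("loss/train", train_loss, epoch)
    writer.add_scalar("loss/val", val_loss, epoch)
    # Random values stand in for a layer's gradients
    writer.add_histogram("grads/layer0", torch.randn(256) * 0.01, epoch)

writer.close()
print(f"View with: tensorboard --logdir {log_dir}")
```

Tags that share a prefix such as `loss/` are grouped together in the TensorBoard UI, which keeps training and validation curves side by side.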
// Weights & Biases
Weights & Biases (W&B) is a common choice for teams that need collaboration or more detailed tracking.
Setup takes two calls: wandb.init() at the start of a run and wandb.log() inside the training loop. Everything syncs to a cloud dashboard automatically, and runs are grouped by project, making experiment comparison straightforward.
Check the code snippet below:
import wandb

wandb.init(project="my-model", config={"lr": 0.001, "epochs": 20, "batch_size": 32})

for epoch in range(wandb.config.epochs):
    train_loss = 1 / (1 + 0.3 * epoch)  # simulated
    val_loss = train_loss + max(0, 0.04 * (epoch - 10))  # simulated
    wandb.log({"epoch": epoch, "train_loss": train_loss, "val_loss": val_loss})

wandb.finish()
Once the run finishes, the logged metrics can be viewed in the W&B dashboard, alongside the configuration that produced them. Comparing two runs with different parameters can easily be done by selecting them in the interface, with no manual log parsing needed.
W&B also supports hyperparameter sweeps with built-in visualization, showing which hyperparameters affected the outcome the most.
System metrics like GPU utilization and memory usage are also logged automatically. For teams running many experiments in parallel, the shared workspace removes a lot of the manual overhead of keeping track of what was tried.
// Sacred
Sacred takes a different approach. It focuses on reproducibility rather than visualization. We annotate a training script with Sacred's experiment decorator, which records the entire configuration, any changes made during runtime, and all recorded metrics in a database (usually MongoDB). This way, each run and its precise settings turn into a permanent record.
For the visualization part, Sacred pairs with front-ends like Omniboard or Sacredboard. This adds complexity compared to TensorBoard or W&B, but the strength is auditability: any run from the past can be reproduced exactly as it was configured.
// Guild.ai
Guild.ai works from the command line and doesn't require changes to the training code. We run a training script through Guild using guild run train.py, which records all the logs produced by the script along with any output files, linking them to that particular run. Metrics and run comparisons are available through Guild's command-line interface (CLI) or its local UI.
This framework is a good choice when working with existing scripts or third-party code that we prefer not to modify. It provides fewer features than W&B, but the setup cost is also lower.
# Using Breakpoints and Hooks for Machine Learning Computations #
// Forward and Backward Hooks
PyTorch's hook system lets us intercept computations at any point in a model's forward or backward pass. The register_forward_hook function attaches a callback to any layer, and it fires every time that layer processes a batch. The callback receives the layer's input and output tensors, which we can then log, check for NaN values, or plot.
The register_backward_hook function does the same for the backward pass, giving us access to the gradient tensors flowing through each layer (recent PyTorch releases deprecate it in favor of register_full_backward_hook, which takes the same kind of callback). Together, forward and backward hooks cover most of what we'd want to inspect during training without modifying the model definition or the training loop.
A practical application is the detection of NaN values. A forward hook that evaluates tensor.isnan().any() at every layer's output flags numerical instability right away, so we can stop before it spreads and damages the rest of the training.
Here's a minimal working example, using a three-layer model with a hook attached to each layer:
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))

def nan_hook(layer, inputs, output):
    # Report whether this layer's output contains any NaN values
    if output.isnan().any():
        print(f"[NaN detected] Layer: {layer.__class__.__name__}")
    else:
        print(f"[Clean] Layer: {layer.__class__.__name__}, output shape: {tuple(output.shape)}")

for layer in model:
    layer.register_forward_hook(nan_hook)

print("--- Normal input ---")
model(torch.randn(2, 8))

print("\n--- Corrupted input ---")
bad_input = torch.randn(2, 8)
bad_input[0, 3] = float('nan')
model(bad_input)
Expected output when run:
--- Normal input ---
[Clean] Layer: Linear, output shape: (2, 16)
[Clean] Layer: ReLU, output shape: (2, 16)
[Clean] Layer: Linear, output shape: (2, 4)
--- Corrupted input ---
[NaN detected] Layer: Linear
[NaN detected] Layer: ReLU
[NaN detected] Layer: Linear
In this example, the hook checks the output tensor after each layer fires and reports whether it's clean or corrupted.
Running it twice, once with normal input and once with a single NaN injected, demonstrates how instability propagates through the network, layer by layer.
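Hooks stay attached until they are removed, so it is usually worth detaching them once the inspection is done. The sketch below assumes the same three-layer model; the `make_stats_hook` helper and the statistics it stores are illustrative choices. It relies on the fact that `register_forward_hook` returns a handle with a `remove()` method.

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
activation_stats = {}


def make_stats_hook(name):
    # Forward hooks receive (module, inputs, output)
    def hook(module, inputs, output):
        activation_stats[name] = (output.mean().item(), output.std().item())
    return hook


# Keep the handles so the hooks can be detached later
handles = [
    layer.register_forward_hook(make_stats_hook(f"{i}:{layer.__class__.__name__}"))
    for i, layer in enumerate(model)
]

model(torch.randn(2, 8))
recorded = dict(activation_stats)  # snapshot of what the hooks captured

for handle in handles:
    handle.remove()

activation_stats.clear()
model(torch.randn(2, 8))  # hooks are gone, so nothing is recorded
print(sorted(recorded), len(activation_stats))
```

Removing hooks after the first few batches keeps the inspection overhead out of the rest of the training run.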
// Debugger Breakpoints
Standard Python debuggers work fine inside training loops.
Dropping import pdb; pdb.set_trace() at any point pauses execution and brings up an interactive prompt that allows us to examine tensor shapes, verify that data preprocessing hasn't produced unexpected values, and manually step through the forward pass.
Most machine learning development environments, VS Code and PyCharm included, let us set breakpoints graphically and inspect tensors in a dedicated pane, offering a quicker alternative to the terminal-based pdb interface.
Breakpoints are particularly valuable during the initial one or two batches, when we confirm that the data, model, and loss function are working properly before starting a complete training run.
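Breakpoints can also be made conditional, so that the debugger opens only when something looks wrong. The following is a minimal sketch under that assumption: the `check_loss` helper is hypothetical, and the simulated loss values stand in for a real training step. It uses Python's built-in `breakpoint()`, which opens pdb by default.

```python
import math


def check_loss(loss_value: float, step: int) -> float:
    """Pause in the debugger if the loss is NaN or infinite, else pass it through."""
    if not math.isfinite(loss_value):
        print(f"Non-finite loss at step {step}")
        # Opens pdb by default; setting PYTHONBREAKPOINT=0 disables it
        # for unattended runs.
        breakpoint()
    return loss_value


# Simulated losses for the first few steps (all finite, so no pause occurs)
losses = [check_loss(1 / (1 + 0.3 * step), step) for step in range(3)]
print(losses)
```

In a real loop the value passed in would be something like `loss.item()`, so the prompt appears at the exact step where the loss first becomes non-finite.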
# Conclusion #
Training a model without visualizing what's happening inside means interpreting symptoms rather than the actual causes.
Whether the loss curve plateaus early, gradients vanish, or embeddings fail to separate, none of these problems announces itself clearly without the right instrumentation.
The tools covered in this article operate at different levels. Loss curves and gradient histograms give continuous feedback during training, catching problems like overfitting or vanishing gradients before they compound and derail the run.
Embedding visualizations reveal whether the model is learning a good separation from the data. TensorBoard, W&B, Sacred, and Guild.ai each handle the logging and tracking side differently, but they all serve the same purpose: making experiment history searchable and comparable rather than scattered. Finally, hooks and debuggers go one step further and let you pause and inspect the actual tensors flowing through the network at any layer.
These tools can't fix a broken model on their own. What they do is shorten the distance between something going wrong and understanding why, which is usually most of the work.
Nate is a data scientist who also works in product strategy. He's also an adjunct professor teaching analytics, and is the founder of StrataScratch, a platform helping data scientists prepare for their interviews with real interview questions from top companies. Nate writes on the latest trends in the career market, gives interview advice, shares data science projects, and covers everything SQL.