[ARTICLE · art-24952] src=marktechpost.com pub= topic=computer-vision verified=true sentiment=↑ positive

A Coding Implementation on MONAI for End-to-End 3D Spleen Segmentation Using UNet on Medical CT Volumes

This tutorial uses MONAI, an open-source medical imaging framework, to build an end-to-end 3D spleen segmentation pipeline with a UNet model on CT volumes from the Medical Segmentation Decathlon Task09 dataset. The implementation applies medical imaging transformations including orientation alignment, voxel-spacing normalization, and patch-based sampling, then trains a 3D UNet with mixed precision and DiceCE loss for binary organ segmentation. The pipeline covers the complete workflow from raw medical volumes through training and validation to qualitative visualization of model predictions against ground-truth masks.

6 min read · published Jun 12, 2026

In this tutorial, we build an end-to-end 3D medical image segmentation pipeline using MONAI to segment the spleen on the Medical Segmentation Decathlon Task09 dataset. We work with volumetric CT scans, apply medical imaging transformations such as orientation alignment, voxel-spacing normalization, intensity windowing, foreground cropping, and patch-based sampling, and then train a 3D UNet model for binary organ segmentation. We also use mixed precision training, DiceCE loss, sliding-window inference, Dice-based validation, and qualitative visualization to understand how the model learns and how its predictions compare with the ground-truth masks. In doing so, we move from raw medical volumes to a complete train–validate–visualize segmentation system.

!pip install -q "monai[nibabel,tqdm,matplotlib]==1.5.2" 2>/dev/null
import os, time, glob, tempfile, warnings
import numpy as np
import torch
import matplotlib.pyplot as plt
from torch.amp import autocast, GradScaler
from monai.apps import DecathlonDataset
from monai.data import DataLoader, decollate_batch
from monai.networks.nets import UNet
from monai.networks.layers import Norm
from monai.losses import DiceCELoss
from monai.metrics import DiceMetric
from monai.inferers import sliding_window_inference
from monai.utils import set_determinism
from monai.transforms import (
   Compose, LoadImaged, EnsureChannelFirstd, EnsureTyped, Orientationd,
   Spacingd, ScaleIntensityRanged, CropForegroundd, RandCropByPosNegLabeld,
   RandFlipd, RandRotate90d, RandShiftIntensityd, AsDiscrete,
)
warnings.filterwarnings("ignore")

We start by installing MONAI with the required medical imaging and visualization dependencies. We then import PyTorch, NumPy, Matplotlib, and the main MONAI modules needed for datasets, transforms, model training, metrics, and inference. We also suppress warnings to keep the notebook output clean while we focus on the segmentation workflow.

QUICK_RUN   = True
device      = torch.device("cuda" if torch.cuda.is_available() else "cpu")
root_dir    = tempfile.mkdtemp()
roi_size    = (96, 96, 96)
num_samples = 4
batch_size  = 2
max_epochs  = 15 if QUICK_RUN else 200
val_every   = 3
train_cache = 8 if QUICK_RUN else 24
val_cache   = 2 if QUICK_RUN else 6
set_determinism(seed=0)
print(f"Device: {device} | epochs: {max_epochs} | data dir: {root_dir}")
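# Shared preprocessing for training and validation.
# NOTE: the definition of `common` is missing at this point in the article. The list below
# follows the steps described in the text; the voxel spacing and intensity window are
# ASSUMED values, not confirmed by the source.
common = [
   LoadImaged(keys=["image", "label"]),
   EnsureChannelFirstd(keys=["image", "label"]),
   Orientationd(keys=["image", "label"], axcodes="RAS"),
   Spacingd(keys=["image", "label"], pixdim=(1.5, 1.5, 2.0),
            mode=("bilinear", "nearest")),
   ScaleIntensityRanged(keys=["image"], a_min=-57, a_max=164,
                        b_min=0.0, b_max=1.0, clip=True),
   CropForegroundd(keys=["image", "label"], source_key="image"),
]
train_transforms = Compose(common + [
   # Draw num_samples patches of size roi_size per volume. The opening arguments of
   # this call are reconstructed; the positive/negative ratio is left at its default.
   RandCropByPosNegLabeld(
       keys=["image", "label"], label_key="label", spatial_size=roi_size,
       num_samples=num_samples,
       image_key="image", image_threshold=0),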
   RandFlipd(keys=["image", "label"], prob=0.2, spatial_axis=0),
   RandFlipd(keys=["image", "label"], prob=0.2, spatial_axis=1),
   RandFlipd(keys=["image", "label"], prob=0.2, spatial_axis=2),
   RandRotate90d(keys=["image", "label"], prob=0.2, max_k=3),
   RandShiftIntensityd(keys=["image"], offsets=0.10, prob=0.5),
   EnsureTyped(keys=["image", "label"]),
])
val_transforms = Compose(common + [EnsureTyped(keys=["image", "label"])])

We define the main configuration for the tutorial, including the device, dataset directory, patch size, batch size, number of epochs, and cache settings. We then create the preprocessing pipeline for CT volumes by loading images, aligning orientation, resampling voxel spacing, scaling intensities, and cropping the foreground. We also define the training and validation transforms, with the training pipeline adding random crops, flips, rotations, and intensity shifts.
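To make the intensity-scaling step concrete, the sketch below reproduces in plain Python the arithmetic of a clipped linear rescale, which is what a transform such as ScaleIntensityRanged computes per voxel. The helper name `window_intensity` is hypothetical, and the window bounds of -57 and 164 Hounsfield units are assumed example values rather than figures taken from this article.

```python
def window_intensity(value, a_min, a_max, b_min=0.0, b_max=1.0):
    """Clip a raw intensity to [a_min, a_max], then map it linearly to [b_min, b_max]."""
    clipped = min(max(value, a_min), a_max)
    return (clipped - a_min) / (a_max - a_min) * (b_max - b_min) + b_min


# Assumed soft-tissue window (illustrative values only).
A_MIN, A_MAX = -57.0, 164.0

# Air, the window edges, the window midpoint, and dense bone.
for hu in (-1000.0, -57.0, 53.5, 164.0, 1200.0):
    print(f"{hu:8.1f} HU -> {window_intensity(hu, A_MIN, A_MAX):.3f}")
```

Everything below the window collapses to 0 and everything above it to 1, so the network only sees contrast inside the chosen range.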

train_ds = DecathlonDataset(
   root_dir=root_dir, task="Task09_Spleen", section="training",
   transform=train_transforms, download=True, val_frac=0.2,
   cache_num=train_cache, num_workers=2, seed=0)
val_ds = DecathlonDataset(
   root_dir=root_dir, task="Task09_Spleen", section="validation",
   transform=val_transforms, download=False, val_frac=0.2,
   cache_num=val_cache, num_workers=2, seed=0)
# train_ and val_ are the training and validation data loaders.
train_ = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
                    num_workers=2, pin_memory=torch.cuda.is_available())
val_   = DataLoader(val_ds, batch_size=1, shuffle=False,
                    num_workers=1, pin_memory=torch.cuda.is_available())
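print(f"Train volumes: {len(train_ds)} | Val volumes: {len(val_ds)}")
# 3D UNet with 1 input channel (CT) and 2 output channels (background, spleen).
# NOTE: the model definition is missing at this point in the article. The channel widths,
# strides, and residual-unit count below are ASSUMED values, not confirmed by the source.
model = UNet(
   spatial_dims=3, in_channels=1, out_channels=2,
   channels=(16, 32, 64, 128, 256), strides=(2, 2, 2, 2),
   num_res_units=2, norm=Norm.BATCH,
).to(device)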
loss_fn   = DiceCELoss(to_onehot_y=True, softmax=True)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_epochs)
scaler    = GradScaler("cuda", enabled=torch.cuda.is_available())
dice_metric = DiceMetric(include_background=False, reduction="mean")
post_pred   = Compose([AsDiscrete(argmax=True, to_onehot=2)])
post_label  = Compose([AsDiscrete(to_onehot=2)])

We load the Medical Segmentation Decathlon Task09 Spleen dataset using MONAI’s DecathlonDataset. We split the data into training and validation sections, apply the appropriate transforms, and wrap both datasets with PyTorch-style data loaders. We then create a 3D UNet model, define the DiceCE loss, and set up the AdamW optimizer, learning-rate scheduler, mixed-precision scaler, Dice metric, and post-processing steps.
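Since the Dice score drives both the loss and the validation metric, it helps to see the formula on a toy example. The sketch below computes Dice as twice the overlap divided by the sum of the two mask sizes, using flat Python lists in place of 3D volumes. The helper `dice_score` is hypothetical and is not MONAI's implementation; returning NaN when both masks are empty is just one possible convention.

```python
def dice_score(pred, target):
    """Dice overlap between two flat binary masks given as lists of 0/1."""
    if len(pred) != len(target):
        raise ValueError("masks must have the same length")
    intersection = sum(p * t for p, t in zip(pred, target))
    total = sum(pred) + sum(target)
    if total == 0:
        # Both masks are empty, so the ratio is undefined.
        return float("nan")
    return 2.0 * intersection / total


# Three predicted voxels, three true voxels, two of them shared.
pred = [0, 1, 1, 1, 0, 0, 0, 0]
target = [0, 0, 1, 1, 1, 0, 0, 0]
print(f"Dice = {dice_score(pred, target):.4f}")
```

Here the score is 2 * 2 / (3 + 3), or about 0.67. A perfect prediction gives 1.0 and a prediction with no overlap gives 0.0.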

best_dice, best_epoch = -1.0, -1
loss_hist, dice_hist, dice_epochs = [], [], []
best_path = os.path.join(root_dir, "best_spleen_unet.pth")
for epoch in range(1, max_epochs + 1):
   model.train(); epoch_loss, t0 = 0.0, time.time()
   for batch in train_:
       x, y = batch["image"].to(device), batch["label"].to(device)
       optimizer.zero_grad(set_to_none=True)
       with autocast("cuda", enabled=torch.cuda.is_available()):
           logits = model(x)
           loss = loss_fn(logits, y)
       scaler.scale(loss).backward()
       scaler.step(optimizer); scaler.update()
       epoch_loss += loss.item()
   scheduler.step()
   epoch_loss /= len(train_); loss_hist.append(epoch_loss)
   print(f"[{epoch:3d}/{max_epochs}] loss={epoch_loss:.4f} "
         f"lr={scheduler.get_last_lr()[0]:.2e} ({time.time()-t0:.0f}s)")
   if epoch % val_every == 0 or epoch == max_epochs:
       model.eval(); dice_metric.reset()
       with torch.no_grad():
           for vb in val_:
               vx, vy = vb["image"].to(device), vb["label"].to(device)
               with autocast("cuda", enabled=torch.cuda.is_available()):
                   vout = sliding_window_inference(vx, roi_size, 4, model,
                                                   overlap=0.5)
               vout = [post_pred(o)  for o in decollate_batch(vout)]
               vlab = [post_label(o) for o in decollate_batch(vy)]
               dice_metric(y_pred=vout, y=vlab)
       d = dice_metric.aggregate().item()
       dice_hist.append(d); dice_epochs.append(epoch)
       if d > best_dice:
           best_dice, best_epoch = d, epoch
           torch.save(model.state_dict(), best_path)
       print(f"        >> val Dice={d:.4f} (best={best_dice:.4f} @ {best_epoch})")
print(f"\nDone. Best mean Dice {best_dice:.4f} at epoch {best_epoch}.")

We run the full training loop, where each epoch trains the 3D UNet on cropped volumetric patches from the spleen dataset. We use automatic mixed precision to reduce memory usage and speed up training when a GPU is available. We also validate the model at regular intervals using sliding-window inference, track the Dice score, and save the best-performing checkpoint.
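Sliding-window inference is needed because a full CT volume is larger than the 96-voxel patches the network was trained on. The sketch below shows, for a single axis, how overlapping window positions can be laid out so that they cover the whole extent. It is a simplified illustration of the idea and not MONAI's exact code; `window_starts` is a hypothetical helper and the 200-voxel axis length is an assumed example.

```python
import math


def window_starts(image_size, roi_size, overlap):
    """Start indices of overlapping windows that cover one axis of a volume."""
    if image_size <= roi_size:
        # A single window suffices (a real implementation would pad the volume).
        return [0]
    interval = max(int(roi_size * (1.0 - overlap)), 1)  # step between windows
    count = math.ceil((image_size - roi_size) / interval) + 1
    last = image_size - roi_size  # clamp so the final window ends at the border
    return [min(i * interval, last) for i in range(count)]


# Assumed axis length of 200 voxels, 96-voxel patches, 50% overlap.
print(window_starts(200, 96, 0.5))
```

With these numbers the step is 48 voxels and the last window is shifted back to start at 104 so that it ends exactly at the volume border. Predictions from overlapping windows are then blended into one full-size output.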

fig, ax = plt.subplots(1, 2, figsize=(12, 4))
ax[0].plot(range(1, len(loss_hist)+1), loss_hist, "-o", ms=3)
ax[0].set(title="Training loss", xlabel="epoch", ylabel="DiceCE loss")
ax[1].plot(dice_epochs, dice_hist, "-o", color="seagreen", ms=4)
ax[1].set(title="Validation mean Dice", xlabel="epoch", ylabel="Dice"); ax[1].set_ylim(0, 1)
plt.tight_layout(); plt.show()
model.load_state_dict(torch.load(best_path, map_location=device)); model.eval()
with torch.no_grad():
   sample = next(iter(val_))
   img = sample["image"].to(device)
   with autocast("cuda", enabled=torch.cuda.is_available()):
       pred = sliding_window_inference(img, roi_size, 4, model, overlap=0.5)
   pred = torch.argmax(pred, dim=1).cpu().numpy()[0]
   img_np, lab_np = img.cpu().numpy()[0, 0], sample["label"].numpy()[0, 0]
   z = int(np.argmax(lab_np.sum(axis=(0, 1))))
fig, ax = plt.subplots(1, 3, figsize=(13, 5))
ax[0].imshow(img_np[:, :, z], cmap="gray");                  ax[0].set_title("CT slice")
ax[1].imshow(lab_np[:, :, z], cmap="viridis");               ax[1].set_title("Ground truth")
ax[2].imshow(pred[:, :, z], cmap="viridis");                 ax[2].set_title("Prediction")
for a in ax: a.axis("off")
plt.tight_layout(); plt.show()

We first plot the training loss and validation Dice score to see how the model improves over time. We then reload the best-saved model checkpoint and run inference on a single validation volume using sliding-window prediction. We visualize the CT slice, ground-truth mask, and predicted segmentation side by side to inspect the model’s qualitative performance.
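The slice shown in the figure is chosen by summing the label mask over the two in-plane axes and taking the depth index with the largest sum, so the plot lands where the spleen is most visible. The sketch below mirrors that selection on a tiny nested-list volume. The helper `busiest_slice` and the 2x2x3 toy mask are hypothetical, used only for illustration.

```python
def busiest_slice(label):
    """Return the z index with the most foreground voxels in label[x][y][z]."""
    size_x, size_y, size_z = len(label), len(label[0]), len(label[0][0])
    areas = [
        sum(label[x][y][z] for x in range(size_x) for y in range(size_y))
        for z in range(size_z)
    ]
    # On ties, max returns the first index, matching argmax behaviour.
    return max(range(size_z), key=areas.__getitem__)


# Toy 2x2x3 mask: slice z=1 holds three foreground voxels, z=2 holds one.
toy_label = [
    [[0, 1, 0], [0, 1, 1]],
    [[0, 1, 0], [0, 0, 0]],
]
print(busiest_slice(toy_label))
```

Picking the slice this way avoids plotting an empty slice when the organ occupies only a small part of the volume.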

In conclusion, we built a practical MONAI-based workflow for 3D spleen segmentation using a 3D UNet model. We prepared the Medical Segmentation Decathlon dataset, transformed and augmented the CT volumes, trained the model with DiceCE loss, validated it using sliding-window inference, and tracked both loss and Dice score over time. We also inspected the final prediction visually by comparing the CT slice, ground-truth label, and model output side by side. We now have a clear understanding of how MONAI supports medical segmentation tasks from data loading and preprocessing to model training, evaluation, checkpointing, and qualitative analysis.


Sana Hassan, a consulting intern at Marktechpost and dual-degree student at IIT Madras, is passionate about applying technology and AI to address real-world challenges. With a keen interest in solving practical problems, he brings a fresh perspective to the intersection of AI and real-life solutions.
