Introduction
3D image segmentation involves partitioning volumetric data into distinct regions to extract meaningful information such as identifying organs, tumors, etc. With applications ranging from medical diagnosis to industrial inspection and robotics, 3D segmentation plays a pivotal role in understanding complex three-dimensional structures and objects. In this guide, we’ll explore the fundamentals of 3D image segmentation in medical imaging and learn how to leverage the MONAI framework with the UNet architecture for segmentation tasks.
Learning Objectives
- Understand the fundamentals of 3D image segmentation and its significance in medical imaging.
- Explore the architecture and functionalities of the UNet model, a widely used deep learning framework for semantic segmentation tasks.
- Gain familiarity with the MONAI framework and its role in streamlining the development and deployment of deep learning models for medical image analysis.
- Learn the process of preprocessing medical imaging data, including DICOM to NIfTI conversion, and applying MONAI transforms for data augmentation and normalization.
- Master the implementation of a 3D U-Net model for spleen segmentation using MONAI.
This article was published as a part of the Data Science Blogathon.
What is Image Segmentation?
Image segmentation is a fundamental task in computer vision and medical imaging that involves partitioning an image or a volumetric dataset into multiple regions or segments. Let’s break it down.
Input is an image.
It segments the Region of Interest:
- Sets all the pixels belonging to ROI (foreground) to HIGH.
- Sets all the pixels belonging to the background to LOW.
Output is a segmentation mask.
It classifies every pixel of the image into one of the classes, i.e., whether it belongs to the foreground or the background, by estimating per-pixel class probabilities.
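As a tiny, hypothetical illustration (not part of the article's pipeline), per-pixel foreground probabilities can be thresholded into a binary segmentation mask like this:
import numpy as np

# hypothetical per-pixel foreground probabilities from a segmentation model
pixel_probs = np.array([
    [0.05, 0.10, 0.85],
    [0.20, 0.95, 0.90],
    [0.02, 0.30, 0.75],
])

# foreground (ROI) pixels -> 1 (HIGH), background pixels -> 0 (LOW)
mask = (pixel_probs > 0.5).astype(np.uint8)
print(mask)
# [[0 0 1]
#  [0 1 1]
#  [0 0 1]]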
Understanding UNET
In this section, we will build an in-depth understanding of how the UNet architecture works, exploring the elements that make up the encoder and decoder paths along with their respective roles.
UNet employs both a ‘contracting’ and an ‘expansive’ pathway to achieve accurate segmentation. The contracting pathway follows a conventional convolutional network design: it repeatedly applies two 3×3 convolutions, each followed by a ReLU activation, and then downsamples through 2×2 max pooling with a stride of 2.
This process doubles the number of feature channels with each iteration, effectively capturing the context of the image.
On the other hand, the expansive pathway focuses on precise localization by upscaling existing features and halving the number of channels using a 2×2 up-convolution. This is followed by concatenation with the correspondingly cropped feature map from the contracting path and another round of two 3×3 convolutions, each followed by a ReLU activation.
- Encoder captures the context i.e. what the image contains.
- Decoder enables a precise localization i.e. where the object is.
- Skip connections preserve fine details aiding in the accurate reconstruction of the segmentation map.
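To make these pathways concrete, below is a minimal PyTorch sketch of one contracting step and one expansive step. It is purely illustrative (the article itself uses MONAI’s prebuilt UNet later), and it uses padded convolutions so the skip features line up without cropping, unlike the original UNet paper; the names and sizes are example choices.
import torch
import torch.nn as nn

def double_conv(in_ch, out_ch):
    # two 3x3 convolutions, each followed by a ReLU
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1), nn.ReLU(inplace=True),
        nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1), nn.ReLU(inplace=True),
    )

enc = double_conv(1, 64)                      # contracting step: double conv ...
pool = nn.MaxPool2d(kernel_size=2, stride=2)  # ... then 2x2 max pooling with stride 2
up = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)  # 2x2 up-convolution halves the channels
dec = double_conv(128, 64)                    # double conv after concatenation with the skip features

x = torch.randn(1, 1, 96, 96)                  # dummy single-channel input
skip = enc(x)                                  # 1 x 64 x 96 x 96, kept for the skip connection
bottleneck = double_conv(64, 128)(pool(skip))  # 1 x 128 x 48 x 48, channels doubled after downsampling
merged = torch.cat([skip, up(bottleneck)], dim=1)  # upsample and concatenate: 1 x 128 x 96 x 96
out = dec(merged)                              # 1 x 64 x 96 x 96
print(out.shape)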
3D U-Net for Volumetric Segmentation
The 3D U-Net architecture is quite similar to the 2D UNet. It has an analysis path on the left and a synthesis path on the right.
Each layer in the analysis path contains two 3×3×3 convolutions followed by a ReLU, and then a 2×2×2 max pooling with strides of two in each dimension.
Each layer in the synthesis path consists of an up-convolution of 2×2×2 by strides of two in each dimension, followed by two 3×3×3 convolutions each followed by a ReLU.
Shortcut connections from layers of equal resolution in the analysis path provide the essential high-resolution features to the synthesis path. Additionally, a 1×1×1 convolution in the last layer reduces the number of output channels to match the desired number of labels. A batch normalization layer before each ReLU contributes to the stability and efficiency of the network’s training process.
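Below is a minimal, illustrative sketch of one analysis-path layer in this 3D setting. It is not the MONAI model used later in this article; the channel counts and patch size are arbitrary example values.
import torch
import torch.nn as nn

def analysis_layer(in_ch, out_ch):
    # two 3x3x3 convolutions, each with batch normalization before the ReLU
    return nn.Sequential(
        nn.Conv3d(in_ch, out_ch, kernel_size=3, padding=1),
        nn.BatchNorm3d(out_ch), nn.ReLU(inplace=True),
        nn.Conv3d(out_ch, out_ch, kernel_size=3, padding=1),
        nn.BatchNorm3d(out_ch), nn.ReLU(inplace=True),
    )

layer = analysis_layer(1, 32)
pool = nn.MaxPool3d(kernel_size=2, stride=2)   # 2x2x2 max pooling with stride 2 in each dimension

volume = torch.randn(1, 1, 64, 64, 64)         # dummy CT patch: batch x channels x D x H x W
features = layer(volume)                       # 1 x 32 x 64 x 64 x 64
downsampled = pool(features)                   # 1 x 32 x 32 x 32 x 32
print(features.shape, downsampled.shape)

# a final 1x1x1 convolution maps the feature channels to the number of labels (2 here: background, spleen)
head = nn.Conv3d(32, 2, kernel_size=1)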
What is MONAI?
MONAI (Medical Open Network for AI) is an open-source, community-driven framework designed to facilitate medical image analysis with deep learning. At its core, MONAI provides a rich set of functionalities to facilitate every stage of the medical image analysis pipeline. From data preprocessing and augmentation to model training, evaluation, and deployment, MONAI offers an intuitive workflow designed to streamline the research process.
One of the key strengths of MONAI lies in its extensive library of pre-built components and algorithms, spanning a wide range of tasks such as image transformation, segmentation, registration, and classification.
Implementing UNet with MONAI
Here we will discuss image segmentation in detail, particularly spleen segmentation, using MONAI.
The first step is to install MONAI and load all the necessary libraries. You can install MONAI with ‘pip install monai’ and import the necessary libraries.
from monai.utils import first, set_determinism
from monai.transforms import (
    AsDiscrete,
    AsDiscreted,
    EnsureChannelFirstd,
    Compose,
    CropForegroundd,
    LoadImaged,
    Orientationd,
    RandCropByPosNegLabeld,
    SaveImaged,
    ScaleIntensityRanged,
    Spacingd,
    Invertd,
)
from monai.handlers.utils import from_engine
from monai.networks.nets import UNet
from monai.networks.layers import Norm
from monai.metrics import DiceMetric
from monai.losses import DiceLoss
from monai.inferers import sliding_window_inference
from monai.data import CacheDataset, DataLoader, Dataset, decollate_batch
from monai.config import print_config
from monai.apps import download_and_extract
import torch
import matplotlib.pyplot as plt
import tempfile
import shutil
import os
import glob
from tqdm import tqdm
Understanding 3D Medical Imaging Data
When we talk about 3D image segmentation, we typically deal with NIfTI files. The 3D volume, in this case a CT scan, is stored in NIfTI format, while each slice making up that volume is stored as a DICOM file. To build intuition, think of a CT scan as a video and each DICOM file as one frame of that video.
We will be using the spleen dataset from the Medical Segmentation Decathlon, available at http://medicaldecathlon.com/.
DICOM (Digital Imaging and Communications in Medicine) files are the standard format for storing medical imaging data, encompassing various modalities such as X-rays, MRI scans, CT scans, and ultrasounds.
These files contain both image data and metadata, including patient information and acquisition parameters. DICOM groups are collections of DICOM files that are related to each other, such as images from the same study, series, or patient. NIfTI (Neuroimaging Informatics Technology Initiative) files, on the other hand, are commonly used in neuroimaging for storing volumetric brain imaging data, such as MRI scans.
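As a quick sanity check, a NIfTI volume can be inspected with the nibabel package (an assumption here; it is not in the article's import list), using a hypothetical file path:
import nibabel as nib

nifti_path = "spleen_patient_0.nii.gz"    # hypothetical converted volume

volume = nib.load(nifti_path)
data = volume.get_fdata()                 # the 3D voxel array
print(data.shape)                         # e.g. (512, 512, 40): height x width x number of slices
print(volume.header.get_zooms())          # voxel spacing in mm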
Creating DICOM Groups
Creating DICOM groups involves organizing DICOM files based on their attributes. The function below creates DICOM folders, each containing a group with a fixed number of DICOM files (slices).
- in_dir: the path to your folders that contain dicom files.
- out_dir: the path where you want to put the converted NIFTI files.
- number_slices: the number of slices you need per group; the function will create groups of this size.
from glob import glob
import os
import shutil

def create_groups(in_dir, out_dir, number_slices):
    for patient in glob(in_dir + '/*'):
        patient_name = os.path.basename(os.path.normpath(patient))

        # calculate the number of folders, each holding
        # number_slices dicom files belonging to the same patient
        number_folders = int(len(glob(patient + '/*')) / number_slices)

        for i in range(number_folders):
            output_path = os.path.join(out_dir, patient_name + '_' + str(i))
            os.mkdir(output_path)

            # copy the slices for this group into its folder
            dicom_files = glob(patient + '/*')
            for j, file in enumerate(dicom_files[i * number_slices:]):
                if j == number_slices:
                    break
                shutil.copy(file, output_path)
# create groups of image dicom files
create_groups(dicom_files_image_path, dicom_groups_image_path, number_slices=40)
print("Creating Dicom Groups from Image dicoms completed!!\n")

# create groups of label dicom files
create_groups(dicom_files_label_path, dicom_groups_label_path, number_slices=40)
print("Creating Dicom Groups from Label dicoms completed!!\n")
While DICOM is widely used in medical imaging, it may not always be the most convenient format for analysis and processing, so the DICOM groups need to be converted to NIfTI. The conversion process typically involves extracting the relevant metadata and pixel data from the DICOM files and reformatting them into NIfTI-compatible structures. This can be done with dedicated DICOM-to-NIfTI conversion tools such as the dicom2nifti library.
This function will be used to convert the DICOM folder into NIFTI files after creating the groups with the number of slices that you want.
- in_dir: the path to the folder where you have all the patients (folder of all the groups).
- out_dir: the path to the output folder where you want to save the converted NIfTI files.
import dicom2nifti
from glob import glob
from tqdm import tqdm

def dcm2nifti(in_dir, out_dir):
    for folder in tqdm(glob(in_dir + '/*')):
        patient_name = os.path.basename(os.path.normpath(folder))
        # convert the whole DICOM series in this folder into a single compressed NIfTI file
        dicom2nifti.dicom_series_to_nifti(folder, os.path.join(out_dir, patient_name + '.nii.gz'))
dcm2nifti(dicom_groups_image_path, nifti_files_image_path)
print("Conversion from Image Dicom Groups to Nifti files completed!!\n")
3D Spleen Segmentation with MONAI
Now that we have the NIFTI files required for the segmentation task, let’s dive into the segmentation task using MONAI.
Preparing Training and Validation Data
The first step after obtaining the NIfTI files is to prepare the data. The code below builds the file paths for the training and validation data by locating the NIfTI images and label files for spleen segmentation.
It creates a list of dictionaries, each containing an image file path and its corresponding label file path, and then splits them into training and validation sets.
data_dir = os.path.join("/content/drive/Spleen-Segmentation/Data/Task09_Spleen")
train_images = sorted(glob.glob(os.path.join(data_dir, "imagesTr", "*.nii.gz")))
train_labels = sorted(
glob.glob(os.path.join(data_dir, "labelsTr", "*.nii.gz")))
data_dicts = [
{"image": image_name, "label": label_name}
for image_name, label_name in zip(train_images, train_labels)
]
train_files, val_files = data_dicts[:-9], data_dicts[-9:]
Monai Transforms
MONAI provides a set of powerful tools designed to preprocess and augment medical imaging data, called MONAI transforms. These transforms encompass a wide range of operations, including data normalization, resampling, cropping, and intensity adjustments, tailored specifically for medical imaging applications. Applying MONAI transforms to the input data improves the quality, consistency, and relevance of the dataset, which helps improve the performance of deep learning models.
Let’s have a look at a few of the transforms.
- LoadImaged: Used to load the images and labels from the NIFTI files.
- ScaleIntensityRanged: Scales the intensity range of the image between the input range (a_min, a_max) and output range (b_min, b_max) and clips the values outside the range.
- CropForegroundd: Crops images and labels to the smallest bounding box, removing all zero borders to focus on the valid body area of the images and labels.
- Orientationd: Orients images and labels based on the axcodes "RAS", where each axis points toward the patient's:
- R: Right (left to right)
- A: Anterior (posterior to anterior)
- S: Superior (inferior to superior)
- Spacingd: Resamples images and labels to a new voxel spacing, here pixdim=(1.5, 1.5, 2.0).
- RandCropByPosNegLabeld: Randomly crops the samples from a big image based on pos / neg ratio.
- EnsureChannelFirstd: Ensures the data has a ‘channel first’ shape, adding the channel dimension if needed.
Set Up Transforms for Training and Validation
Now that we understand MONAI transforms, let us utilize different Monai transforms for both training and validation data.
The training transforms include loading the images and labels, ensuring a channel-first format, adjusting the intensity range, cropping away the empty background, orienting the volumes, resampling the spacing, and performing random cropping based on positive and negative labels. The validation transforms exclude the random cropping for a consistent evaluation.
train_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        ScaleIntensityRanged(
            keys=["image"], a_min=-57, a_max=164,
            b_min=0.0, b_max=1.0, clip=True,
        ),
        CropForegroundd(keys=["image", "label"], source_key="image"),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        Spacingd(keys=["image", "label"], pixdim=(1.5, 1.5, 2.0),
                 mode=("bilinear", "nearest")),
        RandCropByPosNegLabeld(
            keys=["image", "label"],
            label_key="label",
            spatial_size=(96, 96, 96),
            pos=1,
            neg=1,
            num_samples=4,
            image_key="image",
            image_threshold=0,
        ),
    ]
)
val_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        ScaleIntensityRanged(
            keys=["image"], a_min=-57, a_max=164,
            b_min=0.0, b_max=1.0, clip=True,
        ),
        CropForegroundd(keys=["image", "label"], source_key="image"),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        Spacingd(keys=["image", "label"], pixdim=(1.5, 1.5, 2.0),
                 mode=("bilinear", "nearest")),
    ]
)
Check Transforms in the Data Loader
Now let us check the transforms in the Data Loader. We first load a validation dataset, process the first batch, extract an image and its corresponding label, and display a specific slice (at index 80) from both the image and label for visual inspection in a side-by-side plot.
check_ds = Dataset(data=val_files, transform=val_transforms)
check_loader = DataLoader(check_ds, batch_size=1)
check_data = first(check_loader)
image, label = (check_data["image"][0][0], check_data["label"][0][0])
print(f"image shape: {image.shape}, label shape: {label.shape}")
# plot the slice [:, :, 80]
plt.figure("check", (12, 6))
plt.subplot(1, 2, 1)
plt.title("image")
plt.imshow(image[:, :, 80], cmap="gray")
plt.axis("off")
plt.subplot(1, 2, 2)
plt.title("label")
plt.imshow(label[:, :, 80])
plt.axis("off")
plt.show()
Preprocessing Pipeline
Let us visualize a few of the intermediate preprocessing outputs.
Contrast Adjustment and Intensity Scaling
Crop Foreground
Training UNet Model with MONAI
Let us now define a 3D U-Net model for semantic segmentation, utilizing GPU if available. The model architecture consists of contracting and expanding paths with specified channels and strides, enhanced by residual units.
The training setup includes the Dice loss, Adam optimizer, and a Dice metric for evaluation, targeting multi-class segmentation with background excluded.
device = "cuda" if torch.cuda.is_available() else "cpu"

model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
    norm=Norm.BATCH,
).to(device)

loss_function = DiceLoss(to_onehot_y=True, softmax=True)
optimizer = torch.optim.Adam(model.parameters(), 1e-4)
dice_metric = DiceMetric(include_background=False, reduction="mean")
Training Loop
The next step is to train the U-Net model over multiple epochs; here we train for up to 500 epochs, resuming from a saved checkpoint and evaluating on the validation dataset at regular intervals. The loop tracks the loss and Dice metric, and saves checkpoints of the model state, optimizer state, and training progress so that training can be monitored and resumed later.
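Note that the loop below assumes train_loader and val_loader already exist. Their construction is not shown in this article, so here is a minimal sketch following MONAI's usual pattern with the CacheDataset and DataLoader classes imported earlier; the batch sizes, cache_rate, and worker counts are assumed values, not taken from the article.
# minimal sketch of the data loaders assumed by the training loop (illustrative settings)
train_ds = CacheDataset(data=train_files, transform=train_transforms, cache_rate=1.0, num_workers=4)
train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=4)

val_ds = CacheDataset(data=val_files, transform=val_transforms, cache_rate=1.0, num_workers=4)
val_loader = DataLoader(val_ds, batch_size=1, num_workers=4)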
max_epochs = 500
val_interval = 2

# resume from a previously saved checkpoint
checkpoint = torch.load("/content/drive/Spleen-Segmentation/ImprovedResults/my_checkpoint.pth.tar")
best_metric = checkpoint["best_metric"]
best_metric_epoch = checkpoint["best_metric_epoch"]
epoch_loss_values = checkpoint["train_loss"]
metric_values = checkpoint["val_dice"]
model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

post_pred = Compose([AsDiscrete(argmax=True, to_onehot=2)])
post_label = Compose([AsDiscrete(to_onehot=2)])

save_dir = "/content/drive/Spleen-Segmentation/ImprovedResults"
checkpoint = {}

for epoch in range(240, max_epochs):
    print("-" * 10)
    print(f"epoch {epoch + 1}/{max_epochs}")
    model.train()
    epoch_loss = 0
    step = 0
    for batch_data in tqdm(train_loader):
        step += 1
        inputs, labels = (
            batch_data["image"].to(device),
            batch_data["label"].to(device),
        )
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    epoch_loss /= step
    epoch_loss_values.append(epoch_loss)
    print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

    if (epoch + 1) % val_interval == 0:
        model.eval()
        with torch.no_grad():
            for val_data in val_loader:
                val_inputs, val_labels = (
                    val_data["image"].to(device),
                    val_data["label"].to(device),
                )
                roi_size = (160, 160, 160)
                sw_batch_size = 4
                val_outputs = sliding_window_inference(
                    val_inputs, roi_size, sw_batch_size, model)
                val_outputs = [post_pred(i) for i in decollate_batch(val_outputs)]
                val_labels = [post_label(i) for i in decollate_batch(val_labels)]
                # compute metric for current iteration
                dice_metric(y_pred=val_outputs, y=val_labels)

            # aggregate the final mean dice result
            metric = dice_metric.aggregate().item()
            # reset the status for next validation round
            dice_metric.reset()

            metric_values.append(metric)
            if metric > best_metric:
                best_metric = metric
                best_metric_epoch = epoch + 1
                torch.save(model.state_dict(), os.path.join(
                    save_dir, "best_metric_model.pth"))
                print("saved new best metric model")
            print(
                f"current epoch: {epoch + 1} current mean dice: {metric:.4f}"
                f"\nbest mean dice: {best_metric:.4f} "
                f"at epoch: {best_metric_epoch}"
            )

    # save a checkpoint so training can be resumed later
    checkpoint["train_loss"] = epoch_loss_values
    checkpoint["val_dice"] = metric_values
    checkpoint["best_metric_epoch"] = best_metric_epoch
    checkpoint["best_metric"] = best_metric
    checkpoint["model_state_dict"] = model.state_dict()
    checkpoint["optimizer_state_dict"] = optimizer.state_dict()
    torch.save(checkpoint, os.path.join(save_dir, "my_checkpoint2.pth.tar"))
Evaluating Model Performance: Metrics and Visualization
Evaluation of the model is a crucial step that measures the agreement between the predicted and ground truth segmentations. Here we will discuss a few evaluation metrics commonly used in image segmentation tasks and visualize the results by plotting the loss curves and displaying the outputs.
Evaluation Metrics
Evaluation metrics play a crucial role in assessing the performance and accuracy of 3D image segmentation algorithms. Several metrics are commonly used in evaluating 3D image segmentation. Let’s understand a few of them.
Intersection over Union (IoU)
It computes the ratio of the intersection to the union of the segmented and ground truth regions, offering a normalized measure of overlap.
Dice Similarity Coefficient (DSC)
It measures the overlap between the segmented region and ground truth, providing a comprehensive measure of segmentation accuracy.
Dice Loss = 1 – Dice Score
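As a small, self-contained illustration (the article's pipeline itself relies on MONAI's DiceMetric and DiceLoss), both metrics can be computed directly from binary masks; the masks below are hypothetical:
import torch

def dice_score(pred, target, eps=1e-6):
    # DSC = 2 * |A ∩ B| / (|A| + |B|)
    intersection = (pred * target).sum()
    return (2 * intersection + eps) / (pred.sum() + target.sum() + eps)

def iou_score(pred, target, eps=1e-6):
    # IoU = |A ∩ B| / |A ∪ B|
    intersection = (pred * target).sum()
    union = pred.sum() + target.sum() - intersection
    return (intersection + eps) / (union + eps)

# hypothetical binary masks (1 = spleen, 0 = background)
pred = torch.tensor([[0, 1, 1], [0, 1, 0], [0, 0, 0]], dtype=torch.float32)
gt = torch.tensor([[0, 1, 1], [0, 1, 1], [0, 0, 0]], dtype=torch.float32)

print(f"Dice: {dice_score(pred, gt):.3f}")           # 0.857
print(f"IoU: {iou_score(pred, gt):.3f}")             # 0.750
print(f"Dice loss: {1 - dice_score(pred, gt):.3f}")  # 0.143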
Visualization: Plot Loss and Metrics
Now we visualize the training and validation performance of a model during training epochs.
val_interval = 2
plt.figure("train", (15, 5))

plt.subplot(1, 2, 1)
plt.title("Epoch Average Dice Loss")
x = [i + 1 for i in range(len(checkpoint["train_loss"]))]
y = checkpoint["train_loss"]
plt.xlabel("#Epochs")
plt.ylabel("Dice Loss")
plt.plot(x, y)
plt.plot(checkpoint["best_metric_epoch"],
         checkpoint["train_loss"][checkpoint["best_metric_epoch"]], 'r*', markersize=8)

plt.subplot(1, 2, 2)
plt.title("Val Mean Dice Score")
x = [val_interval * (i + 1) for i in range(len(checkpoint["val_dice"]))]
y = checkpoint["val_dice"]
plt.xlabel("#Epochs")
plt.plot(x, y)
plt.plot(checkpoint["best_metric_epoch"],
         checkpoint["val_dice"][checkpoint["best_metric_epoch"] // 2], 'r*', markersize=10)
plt.annotate("Best Score [470, 0.9516]", xy=(checkpoint["best_metric_epoch"],
             checkpoint["val_dice"][checkpoint["best_metric_epoch"] // 2]))

plt.savefig("LearningCurves.png")
plt.show()
The left subplot displays the average dice loss per epoch, with a red star indicating the epoch with the best validation metric. The right subplot illustrates the mean dice score at validation intervals, with an annotation marking the epoch with the best validation score.
Result
After this, we load a trained UNet model from a specified directory, perform inference on validation data using the sliding window inference technique, and visualize the input image, ground truth label, and model output for a slice along the z-axis.
save_dir = "/content/drive/Spleen-Segmentation/ImprovedResults/"

model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
    norm=Norm.BATCH,
).to(device)

model.load_state_dict(torch.load(
    os.path.join(save_dir, "best_metric_model.pth"), map_location=device))
model.eval()

with torch.no_grad():
    for i, val_data in enumerate(val_loader):
        roi_size = (160, 160, 160)
        sw_batch_size = 4
        val_outputs = sliding_window_inference(
            val_data["image"].to(device), roi_size, sw_batch_size, model
        )
        # plot the slice [:, :, 80]
        plt.figure("check", (18, 6))
        plt.subplot(1, 3, 1)
        plt.title(f"image {i}")
        plt.imshow(val_data["image"][0, 0, :, :, 80], cmap="gray")
        plt.axis("off")
        plt.subplot(1, 3, 2)
        plt.title(f"label {i}")
        plt.axis("off")
        plt.imshow(val_data["label"][0, 0, :, :, 80])
        plt.subplot(1, 3, 3)
        plt.title(f"output {i}")
        plt.axis("off")
        plt.imshow(torch.argmax(
            val_outputs, dim=1).detach().cpu()[0, :, :, 80])
        plt.show()
        if i == 3:
            break
Let us visualize the results closely.
Close to ground truth
Better than ground truth
Applications and Case Studies of 3D Image Segmentation
3D image segmentation can extract meaningful insights from complex data. By partitioning images into distinct regions or structures, clinicians can accurately identify and analyze anatomical features, abnormalities, and pathologies. The applications of 3D image segmentation span various medical specialties and clinical scenarios, as discussed below.
Tumor Detection
Accurate tumor detection through 3D image segmentation aids clinicians in diagnosing malignancies and monitoring disease progression for appropriate treatment planning.
Organ Segmentation
Organ segmentation enables clinicians to assess organ function, identify abnormalities, and plan interventions with higher precision and accuracy.
Treatment Planning
Precise segmentation of anatomical structures supports optimal treatment planning, guiding surgical trajectories, and delivering targeted therapies with minimal damage to healthy tissues.
Case Studies
Let us dive deeper into the case studies:
Brain Tumor Segmentation
In a study published in the Journal of Neurosurgery, researchers utilized 3D image segmentation to delineate tumor boundaries in MRI scans of patients with glioblastoma, a type of malignant brain tumor. Accurate segmentation enabled clinicians to assess tumor size, location, and response to treatment, guiding surgical resection and radiation therapy planning.
Cardiac Segmentation for Treatment Planning
In a clinical case presented at a cardiology conference, 3D segmentation of the heart from cardiac MRI scans facilitated treatment planning for patients with congenital heart defects. Precise segmentation of cardiac structures allowed cardiologists to assess ventricular function, identify abnormalities, and plan surgical interventions or cardiac catheterization procedures with improved accuracy and outcomes.
Liver Segmentation in Transplantation
A retrospective analysis of liver transplantation cases demonstrated the utility of 3D liver segmentation from CT scans in surgical planning and donor-recipient matching. Accurate segmentation of liver anatomy enabled surgeons to assess liver volume, vascular structures, and disease extent, facilitating donor selection, graft optimization, and post-transplant monitoring for improved patient outcomes.
These case studies illustrate how 3D image segmentation contributes to improved clinical workflows, personalized treatment planning, and better patient outcomes across a wide range of medical specialties and conditions.
For more details, visit this GitHub repo: https://github.com/bbabina/Spleen-Segmentation-using-Monai-and-Pytorch
Conclusion
In conclusion, 3D image segmentation, particularly in medical imaging, has revolutionized healthcare by providing clinicians with powerful tools to extract valuable insights from complex data. Through techniques like the UNet architecture implemented with the MONAI framework, anatomical structures, tumors, and abnormalities can be segmented accurately, aiding in diagnosis, treatment planning, and monitoring. Furthermore, the diverse applications of 3D segmentation highlighted in the case studies underscore its profound impact on clinical workflows and patient outcomes, promising a future where medical imaging continues to drive advancements in personalized healthcare.
Key Takeaways
- 3D image segmentation enhances the ability to accurately identify and analyze anatomical structures and abnormalities, facilitating precise diagnosis and treatment planning.
- The MONAI framework with the UNet architecture offers a versatile and efficient platform for medical image analysis, providing a rich set of tools and pre-built components tailored for various tasks, including segmentation.
- From tumor detection to organ segmentation and treatment planning, 3D image segmentation finds applications across various medical specialties.
Frequently Asked Questions
Q1. What is 3D image segmentation and why is it important?
A. 3D image segmentation involves dividing volumetric data into distinct regions, which is crucial for tasks like identifying organs and tumors. It plays a pivotal role in medical diagnosis, treatment planning, and monitoring.
Q2. How does the UNet architecture achieve accurate segmentation?
A. The UNet architecture utilizes both contracting and expansive pathways to achieve accurate segmentation. It captures context through convolutional layers and focuses on precise localization by upscaling features. UNet’s skip connections preserve fine details that aid in the reconstruction of segmentation maps with high accuracy.
Q3. How does MONAI support medical image analysis?
A. MONAI offers a set of functionalities tailored for medical image analysis, from data preprocessing to model deployment. Its library of pre-built components and algorithms simplifies tasks like image transformation, segmentation, registration, and classification.
The media shown in this article is not owned by Analytics Vidhya and is used at the Author’s discretion.
By Analytics Vidhya, March 27, 2024.