landmarker

Training and Evaluating Static Heatmap Regression Model for Multi-Instance and Single-Class Landmark Detetection (Plant Centroids)#

TODO

We will go through the following steps:

Open In Colab

Setup environment#

!python -c "import landmarker" || pip install landmarker

Setup imports and variables#

import numpy as np
import torch
from torch.utils.data import DataLoader
from monai.transforms import (Compose, RandAffined, RandGaussianNoised, RandStdShiftIntensityd,
                              RandScaleIntensityd, RandAdjustContrastd, RandHistogramShiftd)
from tqdm import tqdm
device = 'cuda' if torch.cuda.is_available() else 'cpu'

Loading the dataset#

fn_keys = ('image',)
spatial_transformd = [RandAffined(fn_keys, prob=1,
                        rotate_range=(-np.pi/12, np.pi/12),
                        translate_range=(-10, 10),
                        scale_range=(-0.1, 0.1),
                        shear_range=(-0.1, 0.1)
                        )]

composed_transformd = Compose(spatial_transformd)
from landmarker.heatmap.generator import GaussianHeatmapGenerator
from landmarker.datasets import get_plant_centroids_landmark_datasets

ds_train, ds_test_A, ds_test_B, ds_test_C = get_plant_centroids_landmark_datasets("data", transform = composed_transformd, grayscale=False, store_imgs=False,
                                                                                 dim_img=(512,512))
100%|██████████| 1398/1398 [00:00<00:00, 5752.41it/s]
100%|██████████| 100/100 [00:00<00:00, 5978.11it/s]
100%|██████████| 275/275 [00:00<00:00, 6023.15it/s]
100%|██████████| 234/234 [00:00<00:00, 5655.99it/s]
Reading and extracting masks from 1398 images...
100%|██████████| 1398/1398 [00:04<00:00, 326.38it/s]
Reading and extracting masks from 100 images...
100%|██████████| 100/100 [00:00<00:00, 349.23it/s]
Reading and extracting masks from 275 images...
100%|██████████| 275/275 [00:00<00:00, 342.59it/s]
Reading and extracting masks from 234 images...
100%|██████████| 234/234 [00:00<00:00, 335.88it/s]

Constructing a heatmap generator#

from landmarker.heatmap.generator import GaussianHeatmapGenerator

heatmap_generator = GaussianHeatmapGenerator(
    nb_landmarks=1,
    sigmas=3,
    gamma=None,
    heatmap_size=(512, 512),
    learnable=False, # If True, the heatmap generator will be trainable
    device=device
)

Inspecting the dataset#

from landmarker.visualize import inspection_plot

# Plot the first 3 images from the training set
heatmap_generator.device = "cpu" # because dataset tensors are still on cpu
inspection_plot(ds_train, np.random.randint(0, len(ds_train), 3),
               heatmap_generator)
heatmap_generator.device = device # set the desired device back
/opt/conda/lib/python3.11/site-packages/torchvision/transforms/functional.py:1603: UserWarning: The default value of the antialias parameter of all the resizing transforms (Resize(), RandomResizedCrop(), etc.) will change from None to True in v0.17, in order to be consistent across the PIL and Tensor backends. To suppress this warning, directly pass antialias=True (recommended, future default), antialias=None (current default, which means False for Tensors and True for PIL), or antialias=False (only works on Tensors - PIL will still use antialiasing). This also applies if you are using the inference transforms from the models weights: update the call to weights.transforms(antialias=True).
  warnings.warn(
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
../../_images/51d050416f64b909c52219bb812cd150fd9eca68f70f38449e026fff6d777f28.png
# Plot the first 3 images from the test1 set
heatmap_generator.device = "cpu" # because dataset tensors are still on cpu
inspection_plot(ds_test_A, range(3), heatmap_generator)
heatmap_generator.device = device # set the desired device back
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
../../_images/e70e8da52524f444715f8332f2c54e8473c6717d5a2c61f3807b98a6d676b62e.png

Training and initializing the SpatialConfiguration model#

Initializing the model, optimizer and loss function#

from torch import nn

from landmarker.models import get_model
from landmarker.train import EarlyStopping, SaveBestModel

model = get_model("UNet", in_channels=4, out_channels=1).to(device)

lr = 1e-4
batch_size = 4
epochs = 1000

optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-3)

criterion = nn.BCEWithLogitsLoss()

early_stopping = EarlyStopping(patience=10, verbose=True)
save_best_model = SaveBestModel(verbose=True)

lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5,
                                                          patience=20, verbose=True, cooldown=10)

Setting the data loaders and split training set#

split_lengths = [0.8, 0.2]
ds_train_train, ds_train_val = torch.utils.data.random_split(ds_train, split_lengths)
train_loader = DataLoader(ds_train_train, batch_size=batch_size, shuffle=True, num_workers=0)
val_loader = DataLoader(ds_train_val, batch_size=batch_size, shuffle=False, num_workers=0)
test_A_loader = DataLoader(ds_test_A, batch_size=batch_size, shuffle=False, num_workers=0)
test_B_loader = DataLoader(ds_test_B, batch_size=batch_size, shuffle=False, num_workers=0)
test_C_loader = DataLoader(ds_test_C, batch_size=batch_size, shuffle=False, num_workers=0)

Training the model#

from landmarker.heatmap.decoder import heatmap_to_coord, heatmap_to_multiple_coord
from landmarker.metrics import point_error

from torch.nn.functional import sigmoid

from landmarker.metrics.metrics import multi_instance_point_error

def train_epoch(model, heatmap_generator, train_loader, criterion, optimizer, device):
    running_loss = 0
    model.train()
    for i, (images, landmarks, _, _, _, _, _, _) in enumerate(tqdm(train_loader)):
        images = images.to(device)
        landmarks = landmarks.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        heatmaps = heatmap_generator(landmarks)
        loss = criterion(outputs, heatmaps)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    return running_loss / len(train_loader)

def val_epoch(model, heatmap_generator, val_loader, criterion, device):
    eval_loss = 0
    model.eval()
    with torch.no_grad():
        for _, (images, landmarks, _, _, _ , _, _,
                _) in enumerate(tqdm(val_loader)):
            images = images.to(device)
            landmarks = landmarks.to(device)
            outputs = model(images)
            heatmaps = heatmap_generator(landmarks)
            loss = criterion(outputs, heatmaps)
            eval_loss += loss.item()
    return eval_loss / len(val_loader)

def train(model, heatmap_generator, train_loader, val_loader, criterion, optimizer, device, epochs=1000):
    for epoch in range(epochs):
        ds_train.transform = composed_transformd
        train_loss = train_epoch(model, heatmap_generator, train_loader, criterion, optimizer, device)
        ds_train.transform = None
        val_loss = val_epoch(model, heatmap_generator, val_loader, criterion, device)
        print(f"Epoch {epoch+1}/{epochs} - Train loss: {train_loss:.4f} - Val loss: {val_loss:.4f}")
        early_stopping(val_loss)
        save_best_model(val_loss, model)
        lr_scheduler.step(val_loss)
        if early_stopping.early_stop:
            print(f"Loading best model...")
            model.load_state_dict(torch.load(save_best_model.path))
            break
# train(model, heatmap_generator, train_loader, val_loader, criterion, optimizer, device,
#       epochs=epochs)

model.load_state_dict(torch.load("model_static_unet_plant_centroids.pt"))
<All keys matched successfully>

Evaluating the model#

pred_landmarks = []
true_landmarks = []
dim_origs = []
pixel_spacings = []
paddings = []
tp = []
fp = []
fn = []
test_mpe = 0
test_tp = 0
test_fp = 0
test_fn = 0
threshold = 0.05
window = 9
model.eval()
model.to(device)
with torch.no_grad():
    for i, (images, landmarks, affine_matrixs, _, _ ,
        dim_orig, pixel_spacing, padding) in enumerate(tqdm(test_A_loader)):
        images = images
        landmarks = landmarks
        outputs = model(images.to(device)).detach().cpu()
        offset_coords = outputs.shape[1]-landmarks.shape[1]
        pred_landmarks_list, _ = heatmap_to_multiple_coord(sigmoid(outputs), window=window,
                                                           threshold=threshold,
                                                           method="argmax")
        pe_batch, tp_batch, fp_batch, fn_batch, pred_landmarks_torch = multi_instance_point_error(
            true_landmarks=landmarks, pred_landmarks=pred_landmarks_list, dim=(512, 512),
            dim_orig=dim_orig, pixel_spacing=pixel_spacing, padding=padding, reduction="none")
        test_mpe += torch.nanmean(pe_batch).item()
        test_tp += torch.nansum(tp_batch).item()
        test_fp += torch.nansum(fp_batch).item()
        test_fn += torch.nansum(fn_batch).item()
        pred_landmarks.append(pred_landmarks_torch)
        true_landmarks.append(landmarks)
        dim_origs.append(dim_orig)
        pixel_spacings.append(pixel_spacing)
        paddings.append(padding)
        tp.append(tp_batch)
        fp.append(fp_batch)
        fn.append(fn_batch)


test_mpe /= len(test_A_loader)

print(f"Test Mean PE: {test_mpe:.4f}")
print(f"Test TP: {test_tp:.4f}")
print(f"Test FP: {test_fp:.4f}")
print(f"Test FN: {test_fn:.4f}")
100%|██████████| 25/25 [05:55<00:00, 14.22s/it]
Test Mean PE: 13.3300
Test TP: 165.0000
Test FP: 12.0000
Test FN: 235.0000

from landmarker.metrics import sdr

sdr_test = sdr([4, 5, 10, 20], true_landmarks=torch.cat(true_landmarks, axis=0), pred_landmarks=torch.cat(pred_landmarks, axis=0),
               dim=(512, 512), dim_orig=torch.cat(dim_origs, axis=0).int(), pixel_spacing=torch.cat(pixel_spacings, axis=0),
               padding=torch.cat(paddings, axis=0))
for key in sdr_test:
    print(f"SDR for {key}mm: {sdr_test[key]:.4f}")
SDR for 4mm: 6.0000
SDR for 5mm: 8.7500
SDR for 10mm: 21.7500
SDR for 20mm: 39.2500
from landmarker.visualize.utils import prediction_inspect_plot_multi_instance
model.to("cpu")
prediction_inspect_plot_multi_instance(ds_test_A, model, range(3), threshold=threshold, activation=nn.Sigmoid())
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
../../_images/9c0325474982b4077005b56204a099482472628849794e3fa2f4167f95dc0ffb.png
from landmarker.visualize import plot_cpe

plot_cpe(torch.cat(true_landmarks, axis=0), torch.cat(pred_landmarks, axis=0), dim=(512, 512),
            dim_orig=torch.cat(dim_origs, axis=0).int(), pixel_spacing=torch.cat(pixel_spacings, axis=0),
            padding=torch.cat(paddings, axis=0), class_names=ds_test_A.class_names,
            group=False, title="CPE curve", save_path=None,
            stat='proportion', unit='pixels', kind='ecdf')
../../_images/3e8e533522f5c00517eabf9867dc688f96174fd874c28d60282a2f9b806be73d.png
from landmarker.visualize.evaluation import multi_instance_detection_report

multi_instance_detection_report(torch.cat(true_landmarks, axis=0), torch.cat(pred_landmarks, axis=0),
                                torch.cat(tp, axis=0), torch.cat(fp, axis=0), torch.cat(fn, axis=0), dim=(512, 512),
                                dim_orig=torch.cat(dim_origs, axis=0).int(), pixel_spacing=torch.cat(pixel_spacings, axis=0),
                                padding=torch.cat(paddings, axis=0), class_names=ds_test_A.class_names)
Detection report:
1# Instance detection statistics:
============================================================
Class               TP        FP        FN        TPR       
------------------------------------------------------------
landmark_0          165.0     12.0      235.0     0.41      
============================================================

2# Point-to-point error (PE) statistics:
======================================================================
Class               Mean      Median    Std       Min       Max       
----------------------------------------------------------------------
landmark_0          15.08     9.61      51.79     0.72      637.80    
======================================================================

3# Success detection rate (SDR):
================================================================================
Class               SDR (PE≤2mm)   SDR (PE≤2.5mm) SDR (PE≤3mm)   SDR (PE≤4mm)   
--------------------------------------------------------------------------------
landmark_0          1.50           2.50           4.50           6.00           
================================================================================