Introduction

Image segmentation is the task of assigning a label to every pixel in an image. It is one of the fundamental tasks in computer vision, with applications in areas such as biomedical imaging and self-driving cars.

In this article, we will explore the U-Net architecture and implement it in PyTorch using the Carvana dataset from Kaggle.

U-Net Architecture

The U-Net architecture is built entirely from convolution layers. Convolution layers are independent of the input image resolution: we can train the model on 224x224 images and then run it on 512x512 images. This is one of the significant advantages of a fully convolutional design.
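A minimal illustration of this size-independence: the same convolution layer happily accepts either resolution.

import torch
import torch.nn as nn

# The same 3x3 convolution works for any spatial size
conv = nn.Conv2d(3, 64, kernel_size=3, padding=1)
print(conv(torch.randn(1, 3, 224, 224)).shape)  # torch.Size([1, 64, 224, 224])
print(conv(torch.randn(1, 3, 512, 512)).shape)  # torch.Size([1, 64, 512, 512])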

The architecture is as follows:

U-Net

It consists of 3 parts -

  • Encoder

  • Bottleneck

  • Decoder

Architecture Breakdown

Encoder

The encoder path is where we downsample the image, using max-pooling layers. Each downsampling stage uses two 3x3 convolution layers: the first stage has 64 filters, the second 128, and so on, doubling at every stage (the filter counts are a design choice; here I am following the paper). In the image above, each convolution layer in the encoder path is followed by a ReLU activation function.
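For illustration, the first encoder stage might look like this (a sketch; we will wrap the two convolutions into a reusable module later):

import torch.nn as nn

# First encoder stage: two 3x3 convolutions with 64 filters each,
# followed by 2x2 max-pooling to halve the spatial size
stage1 = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=3, padding=1), nn.ReLU(inplace=True),
    nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.ReLU(inplace=True),
)
pool = nn.MaxPool2d(kernel_size=2, stride=2)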

Bottleneck

This is the transition part of the architecture: the deepest feature map from the encoder passes through one more convolution block here before being handed to the decoder, where it is upsampled and the segmentation map is produced.

Decoder

The decoder path is where we upsample the image back to full resolution. It consists of transpose convolution layers for upsampling, each followed by a double convolution block. To recover the spatial detail lost during downsampling, we use skip connections that concatenate feature maps from the same-sized layers of the encoder path.
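A minimal sketch of one decoder step: a transpose convolution doubles the spatial size, and the matching encoder feature map is concatenated along the channel dimension.

import torch
import torch.nn as nn

up = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
x = torch.randn(1, 128, 16, 16)    # feature map coming up from below
skip = torch.randn(1, 64, 32, 32)  # matching encoder feature map
x = up(x)                          # -> (1, 64, 32, 32)
x = torch.cat([skip, x], dim=1)    # -> (1, 128, 32, 32), fed to a double conv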

PyTorch Implementation

We will be using the Carvana dataset from Kaggle. The training set consists of 5088 images of size 1280x1918, each with a corresponding mask. Since the Kaggle test set ships without masks, we will carve a validation split out of the training set.

The Dataset

Kaggle Link

Dataset Preview

We take the last 500 images for validation and the rest for training.
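A sketch of that split, assuming the images have been extracted into flat folders (the paths here are hypothetical):

import os

files = sorted(os.listdir("dataset/train_images"))  # hypothetical layout
train_files, val_files = files[:-500], files[-500:]
print(len(train_files), len(val_files))  # 4588 500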

Further, we perform data augmentation to create a new, expanded dataset, using the albumentations library. I usually write the augmented images out to a folder and scan through them to check whether any abnormality exists; if everything looks fine, I proceed with training on the expanded dataset.

Augmentation

The following are the transformations that I have used for the dataset.

import numpy as np
from PIL import Image
import albumentations as A

def augmentation(image, mask):
    # Load the image/mask pair from disk as numpy arrays
    image = Image.open(image).convert('RGB')
    image = np.array(image)

    mask = Image.open(mask).convert('L')
    mask = np.array(mask)

    # Horizontal flip
    aug = A.HorizontalFlip(p=1.0)
    augmented = aug(image=image, mask=mask)
    i2 = augmented['image']
    m2 = augmented['mask']

    # Vertical flip
    aug = A.VerticalFlip(p=1.0)
    augmented = aug(image=image, mask=mask)
    i3 = augmented['image']
    m3 = augmented['mask']

    # Grid distortion
    aug = A.GridDistortion(p=1.0)
    augmented = aug(image=image, mask=mask)
    i4 = augmented['image']
    m4 = augmented['mask']

    # Return the original plus the three augmented versions
    augmented_image_list = [image, i2, i3, i4]
    augmented_mask_list = [mask, m2, m3, m4]

    return augmented_image_list, augmented_mask_list
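To build the expanded dataset on disk, you can loop over the image/mask pairs and save each variant. This is a sketch: the folder layout is hypothetical (chosen to match the augmented_* directories used during training below), and it assumes masks share filenames with their images.

import os
from PIL import Image

img_dir, mask_dir = "dataset/train_images", "dataset/train_masks"  # hypothetical layout
out_img_dir = "dataset/augmented_train_images"
out_mask_dir = "dataset/augmented_train_masks"
os.makedirs(out_img_dir, exist_ok=True)
os.makedirs(out_mask_dir, exist_ok=True)

for name in sorted(os.listdir(img_dir)):
    images, masks = augmentation(
        os.path.join(img_dir, name), os.path.join(mask_dir, name))
    stem = os.path.splitext(name)[0]
    for i, (im, m) in enumerate(zip(images, masks)):
        Image.fromarray(im).save(os.path.join(out_img_dir, f"{stem}_{i}.png"))
        Image.fromarray(m).save(os.path.join(out_mask_dir, f"{stem}_{i}.png"))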

Augmentation In Process

Dataset Creation

We create our dataset class by inheriting from torch.utils.data.Dataset and overriding the __len__ and __getitem__ methods: __len__ returns the length of the dataset, and __getitem__ returns the image and mask at the given index.

# Filename dataset.py

# Importing libraries
import os
from PIL import Image
from torch.utils.data import Dataset
import numpy as np

# Class
class image_dataset(Dataset):
    # Init method
    def __init__(self, img_dir, mask_dir, transform=None):
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.images = os.listdir(self.img_dir)

    # Returns the length of the dataset
    def __len__(self):
        return len(self.images)

    # Returns the image and mask at the given index
    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.images[idx])
        # Masks share the same filename as their images in our augmented folders;
        # adjust this if your mask names differ (e.g. Carvana's original "_mask.gif" suffix)
        mask_path = os.path.join(self.mask_dir, self.images[idx])
        image = np.array(Image.open(img_path).convert('RGB'))
        mask = np.array(Image.open(mask_path).convert('L'), dtype=np.float32)
        mask[mask == 255.0] = 1.0

        if self.transform is not None:
            augmentations = self.transform(image=image, mask=mask)
            image = augmentations["image"]
            mask = augmentations["mask"]
        return image, mask

Now our dataset class is ready to be used in the DataLoader.
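A quick sanity check of the class (using the hypothetical augmented folders from earlier):

from dataset import image_dataset

ds = image_dataset("dataset/augmented_train_images",
                   "dataset/augmented_train_masks")
image, mask = ds[0]  # numpy arrays until a transform converts them to tensors
print(len(ds), image.shape, mask.shape)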

The Model

As mentioned previously, the architecture is fully convolutional and consists of three parts: Encoder, Bottleneck and Decoder.

Importing the required libraries.

import torch
import torch.nn as nn
import torchvision.transforms.functional as TF

Each stage of the architecture is built from a double convolution block, so we create that first.

class doubleConv(nn.Module):
    def __init__(self, in_channel, out_channel):
        super(doubleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, kernel_size=3,stride=1, padding=1, bias=False), # Bias is set to False as we are using batch normalization
            nn.BatchNorm2d(out_channel),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channel, out_channel, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(out_channel),
            nn.ReLU(inplace=True)
        )
    def forward(self, x):
        return self.conv(x)
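The block keeps the spatial size (3x3 convolutions with padding=1) and only changes the channel count, which we can verify quickly:

import torch

block = doubleConv(3, 64)
print(block(torch.randn(1, 3, 128, 128)).shape)  # torch.Size([1, 64, 128, 128])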

Now the actual model.

Init method

We create empty nn.ModuleLists and append layers to them in a for loop, one loop each for the encoder and decoder parts.

The feature sizes are given as an input parameter to the class.

Since the input image is RGB, in_channels = 3. The output is binary, so out_channels = 1 (it's either car or not car).

class unet(nn.Module):
    def __init__(self, in_channels=3, out_channels=1, features=[64, 128, 256, 512]):
        super(unet, self).__init__()
        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder
        for feature in features:
            self.encoder.append(doubleConv(in_channels, feature))
            in_channels = feature

        # Decoder
        for feature in reversed(features):
            self.decoder.append(nn.ConvTranspose2d(
                feature * 2, feature, kernel_size=2, stride=2))
            self.decoder.append(doubleConv(feature * 2, feature))

        # bottleneck
        self.middle = doubleConv(features[-1], features[-1]*2)
        self.output_conv = nn.Conv2d(
            features[0], out_channels, kernel_size=1, stride=1)

Forward method

In the original paper, skip connections are present. To support them, we first create an empty list and append the encoder outputs to it.

Encoder step: Since we defined the encoder as a ModuleList in the init method, we can loop through it. We pass the input through each stage, append the output to the skip-connections list, and then downsample with max-pooling.

The output is then passed through the bottleneck layer. At this step, the skip-connections list is reversed, because the decoder proceeds from the bottleneck back towards the first layer.

Decoder step: We loop through the decoder list two layers at a time: the transpose convolution upsamples the feature map, then we concatenate the matching entry from the skip-connections list along the channel dimension and pass the result through the double convolution block. This concatenation is the skip connection. If the shapes do not match exactly (for input sizes not divisible by 16), we resize before concatenating.

Then finally, we pass the output into the output convolution layer, which will return the segmentation mask.

    def forward(self, x):
        skip_connections = []

        for encode in self.encoder:
            x = encode(x)
            skip_connections.append(x)
            x = self.pool(x)

        x = self.middle(x)
        skip_connections = skip_connections[::-1]

        for i in range(0, len(self.decoder), 2):
            x = self.decoder[i](x)
            skip_conn = skip_connections[i//2]

            if x.shape != skip_conn.shape:
                x = TF.resize(x, size=skip_conn.shape[2:])

            cat_skip = torch.cat([skip_conn, x], dim=1)
            x = self.decoder[i+1](cat_skip)

        x = self.output_conv(x)
        return x
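A quick sanity check: thanks to the padded convolutions, the output mask has the same spatial size as the input.

import torch

model = unet(in_channels=3, out_channels=1)
x = torch.randn(1, 3, 320, 480)
print(model(x).shape)  # torch.Size([1, 1, 320, 480])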

A Utility File

Here, we create a utility file (utils.py, which the training script imports from later) for some commonly used tasks.

Imports

import torch
from dataset import image_dataset
from torch.utils.data import DataLoader
import torchvision

Loading and saving checkpoints

def save_checkpoint(state, filename="checkpoint.pth"):
    print("=> Saving checkpoint")
    torch.save(state, filename)

def load_checkpoint(checkpoint, model):
    print("=> Loading checkpoint")
    model.load_state_dict(checkpoint["state_dict"])

Data loaders

def get_loaders(train_dir, train_mask_dir, val_dir, val_mask_dir, batch_size, num_workers=4, pin_memory=True, transform=None):
    train_ds = image_dataset(train_dir, train_mask_dir, transform=transform)

    train_loader = DataLoader(
        train_ds,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
        shuffle=True,
    )

    val_ds = image_dataset(val_dir, val_mask_dir, transform=transform)
    val_loader = DataLoader(
        val_ds,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
        # Keep validation order fixed so saved predictions are comparable across epochs
        shuffle=False,
    )

    return train_loader, val_loader

Checking Accuracy

For image segmentation, we use the Dice score.
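For binary masks, the Dice score measures the overlap between the predicted mask P and the ground truth G:

Dice = 2 * |P ∩ G| / (|P| + |G|)

A score of 1 means perfect overlap. In the code below this is the 2 * (preds * y).sum() / ((preds + y).sum()) term, with 1e-8 added to guard against division by zero.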

def check_accuracy(loader, model, device="cuda"):
    num_correct = 0
    num_pixels = 0
    dice_score = 0
    model.eval()

    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            y = y.to(device).unsqueeze(1)
            preds = torch.sigmoid(model(x))
            preds = (preds > 0.5).float()
            num_correct += (preds == y).sum()
            num_pixels += torch.numel(preds)
            dice_score += (2 * (preds * y).sum()) / (
                (preds + y).sum() + 1e-8
            )

    print(
        f"Got {num_correct}/{num_pixels} with acc {num_correct/num_pixels*100:.2f}"
    )
    print(f"Dice score: {dice_score/len(loader)}")
    model.train()

Function to save batch-wise predictions

def save_predictions(
    loader, model, folder="saved_images/", device="cuda"
):
    model.eval()
    for idx, (x, y) in enumerate(loader):
        x = x.to(device=device)
        with torch.no_grad():
            preds = torch.sigmoid(model(x))
            preds = (preds > 0.5).float()
        # Save predicted masks and ground-truth masks for visual comparison
        torchvision.utils.save_image(preds, f"{folder}pred_{idx}.png")
        torchvision.utils.save_image(y.unsqueeze(1), f"{folder}{idx}.png")

    model.train()

Training

Imports

import torch
from tqdm import tqdm
import numpy as np
import torch.nn as nn
import torch.optim as optim
from unet import unet as UNET
import albumentations as A 
from albumentations.pytorch import ToTensorV2

from utils import (
    check_accuracy,
    load_checkpoint,
    save_checkpoint,
    get_loaders,
    save_predictions
)

Hyperparameters

# HyperParams
LEARNING_RATE = 1e-4  # A scheduler such as ReduceLROnPlateau can be used as well
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 16
NUM_EPOCHS = 20
NUM_WORKERS = 1
PIN_MEMORY = False
LOAD_MODEL = False
TRAIN_IMG_DIR = "./dataset/augmented_train_images"
TRAIN_MASK_DIR = "./dataset/augmented_train_masks"
VAL_IMG_DIR = "./dataset/augmented_test_images"
VAL_MASK_DIR = "./dataset/augmented_test_masks"
IMAGE_SIZE = (320, 480)

Needed Transforms

# Transformations

transform = A.Compose(
    [
        A.Resize(height=IMAGE_SIZE[0], width=IMAGE_SIZE[1]),
        # Mean 0 and std 1 with max_pixel_value=255.0 simply rescales pixels to [0, 1]
        A.Normalize(
            mean=[0.0, 0.0, 0.0],
            std=[1.0, 1.0, 1.0],
            max_pixel_value=255.0,
        ),
        ToTensorV2(),
    ],
)

Initialize the model, loss function, optimizer and DataLoader.

model = UNET(in_channels=3, out_channels=1).to(DEVICE)
loss_fn = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

train_loader, val_loader = get_loaders(
    TRAIN_IMG_DIR,
    TRAIN_MASK_DIR,
    VAL_IMG_DIR,
    VAL_MASK_DIR,
    BATCH_SIZE,
    NUM_WORKERS,
    PIN_MEMORY,
    transform=transform
)

Training Loop

# Train function

def train_fn(loader, model, optimizer, loss_fn, scaler):
    loop = tqdm(loader)

    for batch_idx, (data, targets) in enumerate(loop):
        data = data.to(device=DEVICE)
        targets = targets.float().unsqueeze(1).to(device=DEVICE)

        # forward, using mixed precision
        with torch.cuda.amp.autocast():
            predictions = model(data)
            loss = loss_fn(predictions, targets)

        # backward
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # update tqdm loop
        loop.set_postfix(loss=loss.item())

Doing the actual training


# Check initial accuracy
check_accuracy(val_loader, model, device=DEVICE)
# Pixel accuracy can look high even for a random model, since most pixels are
# background. That is why we also track the Dice score.

# Using mixed precision training
scaler = torch.cuda.amp.GradScaler()
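# (A sketch) Optionally resume from a saved checkpoint; LOAD_MODEL is defined
# in the hyperparameters above and load_checkpoint is imported from utils
if LOAD_MODEL:
    load_checkpoint(torch.load("checkpoint.pth"), model)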

for epoch in range(NUM_EPOCHS):
    print("Epoch: {}".format(epoch))
    train_fn(train_loader, model, optimizer, loss_fn, scaler)

    # save model
    checkpoint = {
        "state_dict": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    }
    save_checkpoint(checkpoint)

    # check accuracy
    check_accuracy(val_loader, model, device=DEVICE)

    # print some examples to a folder
    save_predictions(
        val_loader, model, folder="saved_images/", device=DEVICE
    )

We got a Dice score of around 0.7.

Training Output

Sample output generated by the save_predictions function

Original

Its prediction

Making it better

A Dice score of 0.7 leaves room for improvement; you can see in the predictions that the segmentation is still rough. It can be improved in the following ways:

  • Using higher-resolution images
  • Using a learning-rate scheduler such as ReduceLROnPlateau (see the sketch after this list)
  • Replacing the encoder with a pre-trained backbone such as ResNet or VGG, which usually improves results
  • Adding more layers or larger feature sizes
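Here is a minimal sketch of the scheduler idea, assuming check_accuracy is modified to return the validation Dice score (it currently only prints it):

# A sketch: reduce the learning rate when the validation Dice score plateaus
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="max", factor=0.1, patience=2)

for epoch in range(NUM_EPOCHS):
    train_fn(train_loader, model, optimizer, loss_fn, scaler)
    val_dice = check_accuracy(val_loader, model, device=DEVICE)  # hypothetical return value
    scheduler.step(val_dice)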

Conclusion

We went through the U-Net architecture and implemented it in PyTorch, following the paper. We also trained the model on the Carvana dataset and looked at ways to improve the results.

You can find my repo here - https://github.com/SuperSecureHuman/ML-Experiments/tree/main/U-Net-Image_Segmentation

References

Ronneberger, O., Fischer, P. and Brox, T. (2015). U-Net: Convolutional Networks for Biomedical Image Segmentation. arXiv:1505.04597.
