U-Net: Training Image Segmentation Models in PyTorch - PyImageSearch (2024)

In today’s tutorial, we will be looking at image segmentation and building our own segmentation model from scratch, based on the popular U-Net architecture.

This lesson is the last of a 3-part series on Advanced PyTorch Techniques:

  1. Training a DCGAN in PyTorch (the tutorial 2 weeks ago)
  2. Training an Object Detector from Scratch in PyTorch (last week’s lesson)
  3. U-Net: Training Image Segmentation Models in PyTorch (today’s tutorial)

The computer vision community has devised various tasks, such as image classification, object detection, localization, etc., for understanding images and their content. These tasks give us a high-level understanding of the object class and its location in the image.

In Image Segmentation, we go a step further and ask our model to classify each pixel in our image to the object category it represents. This can be viewed as pixel-level image classification and is a much harder task than simple image classification, detection, or localization. Our model must automatically determine all objects and their precise location and boundaries at a pixel level in the image.

Thus, image segmentation provides a fine-grained understanding of the image and is widely used in medical imaging, autonomous driving, robotic manipulation, etc.

To learn how to train a U-Net-based segmentation model in PyTorch, just keep reading.

U-Net: Training Image Segmentation Models in PyTorch

Throughout this tutorial, we will be looking at image segmentation and building and training a segmentation model in PyTorch. We will focus on a very successful architecture, U-Net, which was originally proposed for medical image segmentation. Furthermore, we will understand the salient features of the U-Net model, which make it an apt choice for the task of image segmentation.

Specifically, we will discuss the following, in detail, in this tutorial:

  • The architectural details of U-Net that make it a powerful segmentation model
  • Creating a custom PyTorch Dataset for our image segmentation task
  • Training the U-Net segmentation model from scratch
  • Making predictions on novel images with our trained U-Net model

U-Net Architecture Overview

The U-Net architecture (see Figure 1) follows an encoder-decoder cascade structure, where the encoder gradually compresses information into a lower-dimensional representation. Then the decoder decodes this information back to the original image dimension. Owing to this, the architecture gets an overall U-shape, which leads to the name U-Net.

Figure 1: The U-Net architecture.

In addition to this, one of the salient features of the U-Net architecture is the skip connections (shown with grey arrows in Figure 1), which enable the flow of information from the encoder side to the decoder side, enabling the model to make better predictions.

Specifically, as we go deeper, the encoder processes information at higher levels of abstraction. This simply means that at the initial layers, the feature maps of the encoder capture low-level details about object texture and edges, and as we gradually go deeper, the features capture high-level information about object shapes and categories.

It is worth noting that to segment objects in an image, both low-level and high-level information is important. For example, a change in texture between objects and edge information can help determine the boundaries of various objects. On the other hand, high-level information about the class to which an object shape belongs can help segment corresponding pixels to correct object classes they represent.

Thus, to use both these pieces of information during predictions, the U-Net architecture implements skip connections between the encoder and decoder. This enables us to take intermediate feature map information from various depths on the encoder side and concatenate it at the decoder side to process and facilitate better predictions.

We will look at the U-Net model in further detail and build it from scratch in PyTorch later in this tutorial.

Our TGS Salt Segmentation Dataset

For this tutorial, we will use the TGS Salt Segmentation dataset. The dataset was introduced as part of the TGS Salt Identification Challenge on Kaggle.

In practice, it is difficult to accurately identify the location of salt deposits from images, even with the help of human experts. Therefore, the challenge required participants to help experts precisely identify the locations of salt deposits from seismic images of the earth's subsurface. This matters because incorrect estimates of salt presence can lead companies to drill at the wrong locations, wasting time and resources.

We use a sub-part of this dataset which comprises 4000 images of size 101×101 pixels, taken from various locations on earth. Here, each pixel corresponds to either salt deposit or sediment. In addition to images, we are also provided with the ground-truth pixel-level segmentation masks of the same dimension as the image (see Figure 2).

Figure 2: Images and their ground-truth segmentation masks from the TGS Salt Segmentation dataset.

The white pixels in the masks represent salt deposits, and the black pixels represent sediment. We aim to correctly predict the pixels that correspond to salt deposits in the images. Thus, we have a binary classification problem where we have to classify each pixel into one of the two classes, Class 1: Salt or Class 2: Not Salt (or, in other words, sediment).

Configuring Your Development Environment

To follow this guide, you need to have the PyTorch deep learning library, matplotlib, OpenCV, imutils, scikit-learn, and tqdm packages installed on your system.

Luckily, these packages are extremely easy to install using pip:

$ pip install torch torchvision
$ pip install matplotlib
$ pip install opencv-contrib-python
$ pip install imutils
$ pip install scikit-learn
$ pip install tqdm

If you need help configuring your development environment for PyTorch, I highly recommend the PyTorch documentation, which is comprehensive and will have you up and running quickly.

Project Structure

We first need to review our project directory structure.

Start by accessing the “Downloads” section of this tutorial to retrieve the source code and example images.

From there, take a look at the directory structure:

.
├── dataset
│   └── train
├── output
├── pyimagesearch
│   ├── config.py
│   ├── dataset.py
│   └── model.py
├── predict.py
└── train.py

The dataset folder stores the TGS Salt Segmentation dataset we will use for training our segmentation model.

Furthermore, we will be storing our trained model and training loss plots in the output folder.

The config.py file in the pyimagesearch folder stores our code’s parameters, initial settings, and configurations.

On the other hand, the dataset.py file consists of our custom segmentation dataset class, and the model.py file contains the definition of our U-Net model.

Finally, our model training and prediction codes are defined in train.py and predict.py files, respectively.

Creating Our Configuration File

We start by discussing the config.py file, which stores configurations and parameter settings used in the tutorial.

# import the necessary packages
import torch
import os

# base path of the dataset
DATASET_PATH = os.path.join("dataset", "train")

# define the path to the images and masks dataset
IMAGE_DATASET_PATH = os.path.join(DATASET_PATH, "images")
MASK_DATASET_PATH = os.path.join(DATASET_PATH, "masks")

# define the test split
TEST_SPLIT = 0.15

# determine the device to be used for training and evaluation
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# determine if we will be pinning memory during data loading
PIN_MEMORY = True if DEVICE == "cuda" else False

We start by importing the necessary packages on Lines 2 and 3. Then, we define the path for our dataset (i.e., DATASET_PATH) on Line 6 and the paths for images and masks within the dataset folder (i.e., IMAGE_DATASET_PATH and MASK_DATASET_PATH) on Lines 9 and 10.

On Line 13, we define the fraction of the dataset we will keep aside for the test set. Then, on Line 16, we define the DEVICE parameter, which selects a CUDA-enabled GPU if one is available and falls back to the CPU otherwise. On Line 19, we set the PIN_MEMORY parameter to True only when DEVICE is "cuda", since pinned (page-locked) host memory speeds up transfers from the CPU to the GPU and is not needed for CPU-only training.

# define the number of channels in the input, number of classes,
# and number of levels in the U-Net model
NUM_CHANNELS = 1
NUM_CLASSES = 1
NUM_LEVELS = 3

# initialize learning rate, number of epochs to train for, and the
# batch size
INIT_LR = 0.001
NUM_EPOCHS = 40
BATCH_SIZE = 64

# define the input image dimensions
INPUT_IMAGE_WIDTH = 128
INPUT_IMAGE_HEIGHT = 128

# define threshold to filter weak predictions
THRESHOLD = 0.5

# define the path to the base output directory
BASE_OUTPUT = "output"

# define the path to the output serialized model, model training
# plot, and testing image paths
MODEL_PATH = os.path.join(BASE_OUTPUT, "unet_tgs_salt.pth")
PLOT_PATH = os.path.sep.join([BASE_OUTPUT, "plot.png"])
TEST_PATHS = os.path.sep.join([BASE_OUTPUT, "test_paths.txt"])

Next, we define the NUM_CHANNELS, NUM_CLASSES, and NUM_LEVELS parameters on Lines 23-25, which we will discuss in more detail later in the tutorial. Finally, on Lines 29-31, we define the training parameters such as initial learning rate (i.e., INIT_LR), the total number of epochs (i.e., NUM_EPOCHS), and batch size (i.e., BATCH_SIZE).

On Lines 34 and 35, we also define input image dimensions to which our images should be resized for our model to process them. We further define a threshold parameter on Line 38, which will later help us classify the pixels into one of the two classes in our binary classification-based segmentation task.
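To make the role of the threshold concrete, here is a minimal sketch in plain Python (no PyTorch required). It assumes the model outputs raw scores (logits) that are squashed with a sigmoid before thresholding; the helper names sigmoid and to_binary_mask and the sample values are illustrative and not part of the tutorial code.

```python
import math

THRESHOLD = 0.5  # same value as config.THRESHOLD


def sigmoid(logit):
    # map a raw score to a probability in (0, 1)
    return 1.0 / (1.0 + math.exp(-logit))


def to_binary_mask(logits, threshold=THRESHOLD):
    # label a pixel as salt (1) when its probability exceeds the threshold,
    # otherwise as sediment (0)
    return [[1 if sigmoid(v) > threshold else 0 for v in row] for row in logits]


# a hypothetical 2x3 patch of raw model outputs
logits = [[-2.0, 0.3, 4.1],
          [0.0, -0.7, 1.5]]
mask = to_binary_mask(logits)
print(mask)
```

Note that a logit of exactly 0.0 maps to a probability of 0.5, which this sketch treats as "not salt" because the comparison is strict.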

Finally, we define the path to our output folder (i.e., BASE_OUTPUT) on Line 41 and the corresponding paths to the trained model weights, training plots, and test images within the output folder on Lines 45-47.

Creating Our Custom Segmentation Dataset Class

Now that we have defined our initial configurations and parameters, we are ready to understand the custom dataset class we will be using for our segmentation dataset.

Let’s open the dataset.py file from the pyimagesearch folder in our project directory.

# import the necessary packages
from torch.utils.data import Dataset
import cv2

class SegmentationDataset(Dataset):
    def __init__(self, imagePaths, maskPaths, transforms):
        # store the image and mask filepaths, and augmentation
        # transforms
        self.imagePaths = imagePaths
        self.maskPaths = maskPaths
        self.transforms = transforms

    def __len__(self):
        # return the number of total samples contained in the dataset
        return len(self.imagePaths)

    def __getitem__(self, idx):
        # grab the image path from the current index
        imagePath = self.imagePaths[idx]

        # load the image from disk, swap its channels from BGR to RGB,
        # and read the associated mask from disk in grayscale mode
        image = cv2.imread(imagePath)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        mask = cv2.imread(self.maskPaths[idx], 0)

        # check to see if we are applying any transformations
        if self.transforms is not None:
            # apply the transformations to both image and its mask
            image = self.transforms(image)
            mask = self.transforms(mask)

        # return a tuple of the image and its mask
        return (image, mask)

We begin by importing the Dataset class from the torch.utils.data module on Line 2. This is important since all PyTorch datasets must inherit from this base dataset class. Furthermore, on Line 3, we import the OpenCV package, which will enable us to use its image handling functionalities.

We are now ready to define our own custom segmentation dataset. Each PyTorch dataset is required to inherit from the Dataset class (Line 5) and should have a __len__ (Lines 13-15) and a __getitem__ (Lines 17-34) method. We discuss each of these methods below.

We start by defining the constructor, that is, the __init__ method on Lines 6-11. The method takes as input the list of image paths (i.e., imagePaths) of our dataset, the list of paths to the corresponding ground-truth masks (i.e., maskPaths), and the set of transformations (i.e., transforms) we want to apply to our input images (Line 6).

On Lines 9-11, we initialize the attributes of our SegmentationDataset class with the parameters input to the __init__ constructor.

Next, we define the __len__ method, which returns the total number of image paths in our dataset, as shown on Line 15.

The task of the __getitem__ method is to take an index as input (Line 17) and return the corresponding sample from the dataset. On Line 19, we simply grab the image path at the idx index in our list of input image paths. Then, we load the image using OpenCV (Line 23). By default, OpenCV loads an image in the BGR format, which we convert to the RGB format as shown on Line 24. We also load the corresponding ground-truth segmentation mask in grayscale mode on Line 25.

Finally, we check for input transformations that we want to apply to our dataset images (Line 28) and transform both the image and mask with the required transforms on Lines 30 and 31, respectively. This is important since we want our image and ground-truth mask to correspond and have the same dimension. On Line 34, we return the tuple containing the image and its corresponding mask (i.e., (image, mask)) as shown.
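The two-method protocol described above does not depend on image loading, so it can be tried out without OpenCV or PyTorch. The sketch below uses a hypothetical ToyPairDataset with an injected loader function standing in for cv2.imread; the class name, the loader argument, and the length check are assumptions made for illustration and are not part of the tutorial code.

```python
class ToyPairDataset:
    """Stand-in for SegmentationDataset: same protocol, no image I/O."""

    def __init__(self, imagePaths, maskPaths, loader, transforms=None):
        # image and mask lists must line up index by index
        if len(imagePaths) != len(maskPaths):
            raise ValueError("imagePaths and maskPaths differ in length")
        self.imagePaths = imagePaths
        self.maskPaths = maskPaths
        self.loader = loader  # hypothetical replacement for cv2.imread
        self.transforms = transforms

    def __len__(self):
        # number of (image, mask) pairs
        return len(self.imagePaths)

    def __getitem__(self, idx):
        image = self.loader(self.imagePaths[idx])
        mask = self.loader(self.maskPaths[idx])
        if self.transforms is not None:
            # the same transform is applied to both members of the pair
            image = self.transforms(image)
            mask = self.transforms(mask)
        return (image, mask)


# fake "files": the loader just upper-cases the path, and the
# transform strips the extension
ds = ToyPairDataset(
    imagePaths=["images/a.png", "images/b.png"],
    maskPaths=["masks/a.png", "masks/b.png"],
    loader=str.upper,
    transforms=lambda s: s.replace(".PNG", ""),
)
print(len(ds), ds[1])
```

Because indexing returns the image and mask found at the same position, the two path lists must be built in the same order, which is why the training script sorts both.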

This completes the definition of our custom Segmentation dataset. Next, we will discuss the implementation of the U-Net architecture.

Building Our U-Net Model in PyTorch

It is time to look at our U-Net model architecture in detail and build it from scratch in PyTorch.

We open our model.py file from the pyimagesearch folder in our project directory and get started.

# import the necessary packages
from . import config
from torch.nn import ConvTranspose2d
from torch.nn import Conv2d
from torch.nn import MaxPool2d
from torch.nn import Module
from torch.nn import ModuleList
from torch.nn import ReLU
from torchvision.transforms import CenterCrop
from torch.nn import functional as F
import torch

On Lines 2-11, we import the necessary layers, modules, and activation functions from PyTorch, which we will use to build our model.

Overall, our U-Net model will consist of an Encoder class and a Decoder class. The encoder will gradually reduce the spatial dimension to compress information. Furthermore, it will increase the number of channels, that is, the number of feature maps at each stage, enabling our model to capture different details or features in our image. On the other hand, the decoder will take the final encoder representation and gradually increase the spatial dimension and reduce the number of channels to finally output a segmentation mask of the same spatial dimension as the input image.

Next, we define a Block module as the building unit of our encoder and decoder architecture. It is worth noting that all models or model sub-parts that we define are required to inherit from the PyTorch Module class, which is the parent class in PyTorch for all neural network modules.

class Block(Module):
    def __init__(self, inChannels, outChannels):
        super().__init__()
        # store the convolution and RELU layers
        self.conv1 = Conv2d(inChannels, outChannels, 3)
        self.relu = ReLU()
        self.conv2 = Conv2d(outChannels, outChannels, 3)

    def forward(self, x):
        # apply CONV => RELU => CONV block to the inputs and return it
        return self.conv2(self.relu(self.conv1(x)))

We start by defining our Block class on Lines 13-23. The function of this module is to take an input feature map with the inChannels number of channels, apply two convolution operations with a ReLU activation between them and return the output feature map with the outChannels channels.

The __init__ constructor takes as input two parameters, inChannels and outChannels (Line 14), which determine the number of channels in the input feature map and the output feature map, respectively.

We initialize the two convolution layers (i.e., self.conv1 and self.conv2) and a ReLU activation on Lines 17-19. On Lines 21-23, we define the forward function which takes as input our feature map x, applies self.conv1 => self.relu => self.conv2 sequence of operations and returns the output feature map.
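Since Conv2d is created here with a kernel size of 3 and no padding argument, each convolution uses PyTorch's default padding of 0 and stride of 1, so every Block shrinks the feature map slightly. The following plain-Python sketch works through that arithmetic; the helper names conv_out and block_out are ours and are not part of model.py.

```python
def conv_out(size, kernel=3, padding=0, stride=1):
    # standard output-size formula for a convolution (dilation 1)
    return (size + 2 * padding - kernel) // stride + 1


def block_out(size):
    # Block applies two unpadded 3x3 convolutions; ReLU keeps the size
    return conv_out(conv_out(size))


print(conv_out(128))   # after the first convolution
print(block_out(128))  # after the whole Block
```

Each unpadded 3x3 convolution removes one pixel from every border, so a Block reduces height and width by 4 in total.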

class Encoder(Module):
    def __init__(self, channels=(3, 16, 32, 64)):
        super().__init__()
        # store the encoder blocks and maxpooling layer
        self.encBlocks = ModuleList(
            [Block(channels[i], channels[i + 1])
                for i in range(len(channels) - 1)])
        self.pool = MaxPool2d(2)

    def forward(self, x):
        # initialize an empty list to store the intermediate outputs
        blockOutputs = []

        # loop through the encoder blocks
        for block in self.encBlocks:
            # pass the inputs through the current encoder block, store
            # the outputs, and then apply maxpooling on the output
            x = block(x)
            blockOutputs.append(x)
            x = self.pool(x)

        # return the list containing the intermediate outputs
        return blockOutputs

Next, we define our Encoder class on Lines 25-47. The class constructor (i.e., the __init__ method) takes as input a tuple (i.e., channels) of channel dimensions (Line 26). Note that the first value denotes the number of channels in our input image, and the subsequent numbers gradually double the channel dimension.

We start by initializing a list of blocks for the encoder (i.e., self.encBlocks) with the help of PyTorch’s ModuleList functionality on Lines 29-31. Each Block takes the input channels of the previous block and doubles the channels in the output feature map. We also initialize a MaxPool2d() layer, which reduces the spatial dimension (i.e., height and width) of the feature maps by a factor of 2.

Finally, we define the forward function for our encoder on Lines 34-47. The function takes as input an image x as shown on Line 34. On Line 36, we initialize an empty blockOutputs list, which stores the intermediate outputs from the blocks of our encoder. Note that this will enable us to later pass these outputs to the decoder, where they can be processed with the decoder feature maps.

On Lines 39-44, we loop through each block in our encoder, process the input feature map through the block (Line 42), and add the output of the block to our blockOutputs list. We then apply the max pool operation on our block output (Line 44). This is done for each block in the encoder.

Finally, we return our blockOutputs list on Line 47.
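To see what the blockOutputs list holds, we can trace the shapes by hand. The sketch below assumes a 128x128 input (the size set in config.py), the default channels tuple, and PyTorch's default behavior for unpadded convolutions and MaxPool2d(2); encoder_shapes is an illustrative helper, not part of model.py.

```python
def block_out(size):
    # two unpadded 3x3 convolutions remove 2 pixels each
    return size - 4


def encoder_shapes(size, channels=(3, 16, 32, 64)):
    shapes = []
    for out_ch in channels[1:]:
        size = block_out(size)
        shapes.append((out_ch, size, size))  # stored before pooling
        size = size // 2                     # MaxPool2d(2) rounds down
    return shapes


# (channels, height, width) of each stored encoder output
shapes = encoder_shapes(128)
print(shapes)
```

The pooled output of the last block is computed but never stored, so the deepest feature map handed to the decoder is the unpooled output of the third block.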

class Decoder(Module):
    def __init__(self, channels=(64, 32, 16)):
        super().__init__()
        # initialize the number of channels, upsampler blocks, and
        # decoder blocks
        self.channels = channels
        self.upconvs = ModuleList(
            [ConvTranspose2d(channels[i], channels[i + 1], 2, 2)
                for i in range(len(channels) - 1)])
        self.dec_blocks = ModuleList(
            [Block(channels[i], channels[i + 1])
                for i in range(len(channels) - 1)])

    def forward(self, x, encFeatures):
        # loop through the number of channels
        for i in range(len(self.channels) - 1):
            # pass the inputs through the upsampler blocks
            x = self.upconvs[i](x)

            # crop the current features from the encoder blocks,
            # concatenate them with the current upsampled features,
            # and pass the concatenated output through the current
            # decoder block
            encFeat = self.crop(encFeatures[i], x)
            x = torch.cat([x, encFeat], dim=1)
            x = self.dec_blocks[i](x)

        # return the final decoder output
        return x

    def crop(self, encFeatures, x):
        # grab the dimensions of the inputs, and crop the encoder
        # features to match the dimensions
        (_, _, H, W) = x.shape
        encFeatures = CenterCrop([H, W])(encFeatures)

        # return the cropped features
        return encFeatures

Now we define our Decoder class (Lines 50-87). Similar to the encoder definition, the decoder __init__ method takes as input a tuple (i.e., channels) of channel dimensions (Line 51). Note that the difference here, when compared with the encoder side, is that the channels gradually decrease by a factor of 2 instead of increasing.

We initialize the number of channels on Line 55. Furthermore, on Lines 56-58, we define a list of upsampling blocks (i.e., self.upconvs) that use the ConvTranspose2d layer to upsample the spatial dimension (i.e., height and width) of the feature maps by a factor of 2. In addition, the layer also reduces the number of channels by a factor of 2.

Finally, we initialize a list of blocks for the decoder (i.e., self.dec_blocks) similar to that on the encoder side.

On Lines 63-75, we define the forward function, which takes as input our feature map x and the list of intermediate outputs from the encoder (i.e., encFeatures). Starting on Line 65, we loop through the number of channels and perform the following operations:

  • First, we upsample the input to our decoder (i.e., x) by passing it through our i-th upsampling block (Line 67)
  • Since we have to concatenate (along the channel dimension) the i-th intermediate feature map from the encoder (i.e., encFeatures[i]) with our current output x from the upsampling block, we need to ensure that the spatial dimensions of encFeatures[i] and x match. To accomplish this, we use our crop function on Line 73.
  • Next, we concatenate our cropped encoder feature maps (i.e., encFeat) with our current upsampled feature map x, along the channel dimension on Line 74
  • Finally, we pass the concatenated output through our i-th decoder block (Line 75)

After the completion of the loop, we return the final decoder output on Line 78.

On Lines 80-87, we define our crop function which takes an intermediate feature map from the encoder (i.e., encFeatures) and a feature map output from the decoder (i.e., x) and spatially crops the former to the dimension of the latter.

To do this, we first grab the spatial dimensions of x (i.e., height H and width W) on Line 83. Then, we crop encFeatures to the spatial dimension [H, W] using the CenterCrop function (Line 84) and finally return the cropped output on Line 87.
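The need for cropping becomes clear once we trace the decoder shapes. The sketch below is plain Python and rests on several assumptions: a 128x128 input, the default channel tuples (so each skip feature map has as many channels as the upsampled map), and a crop window centered in the encoder feature map. The helpers center_crop_offset and decoder_trace are illustrative and not part of model.py.

```python
def center_crop_offset(src, dst):
    # offset of the crop window when cutting a dst-wide window
    # out of the middle of a src-wide feature map
    return int(round((src - dst) / 2.0))


def decoder_trace(x_size, skip_sizes, channels=(64, 32, 16)):
    trace = []
    for i in range(len(channels) - 1):
        up = x_size * 2                                 # ConvTranspose2d(kernel 2, stride 2)
        offset = center_crop_offset(skip_sizes[i], up)  # crop skip map to up x up
        concat_ch = channels[i + 1] * 2                 # upsampled + skip channels (mirrored tuples assumed)
        x_size = up - 4                                 # Block: two unpadded 3x3 convs
        trace.append((up, offset, concat_ch, channels[i + 1], x_size))
    return trace


# deepest encoder output is 25x25; skip maps are 58x58 and 124x124
trace = decoder_trace(25, [58, 124])
for step in trace:
    # (upsampled size, crop offset, concat channels, output channels, output size)
    print(step)
```

Under these assumptions the decoder output is 88x88, smaller than the 128x128 input, which is why the full model optionally resizes its final output.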

Now that we have defined the sub-modules that make up our U-Net model, we are ready to build our U-Net model class.

class UNet(Module):
    def __init__(self, encChannels=(3, 16, 32, 64),
         decChannels=(64, 32, 16),
         nbClasses=1, retainDim=True,
         outSize=(config.INPUT_IMAGE_HEIGHT, config.INPUT_IMAGE_WIDTH)):
        super().__init__()
        # initialize the encoder and decoder
        self.encoder = Encoder(encChannels)
        self.decoder = Decoder(decChannels)

        # initialize the regression head and store the class variables
        self.head = Conv2d(decChannels[-1], nbClasses, 1)
        self.retainDim = retainDim
        self.outSize = outSize

We start by defining the __init__ constructor method (Lines 91-103). It takes the following parameters as input:

  • encChannels: The tuple defines the gradual increase in channel dimension as our input passes through the encoder. We start with 3 channels (i.e., RGB) and subsequently double the number of channels.
  • decChannels: The tuple defines the gradual decrease in channel dimension as our input passes through the decoder. We reduce the channels by a factor of 2 at every step.
  • nbClasses: This defines the number of segmentation classes where we have to classify each pixel. This usually corresponds to the number of channels in our output segmentation map, where we have one channel for each class.
    • Since we are working with two classes (i.e., binary classification), we keep a single channel and use thresholding for classification, as we will discuss later.
  • retainDim: This indicates whether the output segmentation map should be resized to outSize, so that it matches the spatial dimensions of the input image.
  • outSize: This determines the spatial dimensions of the output segmentation map. We set this to the same dimension as our input image (i.e., (config.INPUT_IMAGE_HEIGHT, config.INPUT_IMAGE_WIDTH)).

On Lines 97 and 98, we initialize our encoder and decoder networks. Furthermore, we initialize a convolution head that will later take our decoder output as input and output our segmentation map with nbClasses channels (Line 101).

We also initialize the self.retainDim and self.outSize attributes on Lines 102 and 103.

    def forward(self, x):
        # grab the features from the encoder
        encFeatures = self.encoder(x)

        # pass the encoder features through decoder making sure that
        # their dimensions are suited for concatenation
        decFeatures = self.decoder(encFeatures[::-1][0],
            encFeatures[::-1][1:])

        # pass the decoder features through the regression head to
        # obtain the segmentation mask
        map = self.head(decFeatures)

        # check to see if we are retaining the original output
        # dimensions and if so, then resize the output to match them
        if self.retainDim:
            map = F.interpolate(map, self.outSize)

        # return the segmentation map
        return map

Finally, we are ready to discuss our U-Net model’s forward function (Lines 105-124).

We begin by passing our input x through the encoder. This outputs the list of encoder feature maps (i.e., encFeatures) as shown on Line 107. Note that the encFeatures list contains all the feature maps starting from the first encoder block output to the last, as discussed previously. Therefore, we can reverse the order of feature maps in this list: encFeatures[::-1].

Now the encFeatures[::-1] list contains the feature map outputs in reverse order (i.e., from the last to the first encoder block). Note that this is important since, on the decoder side, we will be utilizing the encoder feature maps starting from the last encoder block output to the first.

Next, we pass the output of the final encoder block (i.e., encFeatures[::-1][0]) and the feature map outputs of all intermediate encoder blocks (i.e., encFeatures[::-1][1:]) to the decoder on Line 111. The output of the decoder is stored as decFeatures.

We pass the decoder output to our convolution head (Line 116) to obtain the segmentation mask.

Finally, we check if the self.retainDim attribute is True (Line 120). If yes, we interpolate the final segmentation map to the output size defined by self.outSize (Line 121). We return our final segmentation map on Line 124.
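The list reversal and slicing used in the forward function can be checked with ordinary Python lists. In the sketch below, each encoder output is represented only by its (channels, height, width) shape, assuming a 128x128 input and the default encoder channels; the variable names reversedFeatures, decoderInput, and skipFeatures are ours.

```python
# (channels, height, width) of each stored encoder output, first block first
encFeatures = [(16, 124, 124), (32, 58, 58), (64, 25, 25)]

reversedFeatures = encFeatures[::-1]
decoderInput = reversedFeatures[0]   # deepest feature map, fed in as x
skipFeatures = reversedFeatures[1:]  # remaining maps, deepest first

print(decoderInput)
print(skipFeatures)
```

With this ordering, the i-th decoder step indexes skipFeatures[i] and therefore always pairs with the encoder output one level shallower than its own input.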

This completes the implementation of our U-Net model. Next, we will look at the training procedure for our segmentation pipeline.

Training Our Segmentation Model

Now that we have implemented our dataset class and model architecture, we are ready to construct and train our segmentation pipeline in PyTorch. Let’s open the train.py file from our project directory.

Specifically, we will be looking at the following in detail:

  • Structuring the data-loading pipeline
  • Initializing the model and training parameters
  • Defining the training loop
  • Visualizing the training and test loss curves

# USAGE
# python train.py

# import the necessary packages
from pyimagesearch.dataset import SegmentationDataset
from pyimagesearch.model import UNet
from pyimagesearch import config
from torch.nn import BCEWithLogitsLoss
from torch.optim import Adam
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
from torchvision import transforms
from imutils import paths
from tqdm import tqdm
import matplotlib.pyplot as plt
import torch
import time
import os

We begin by importing our custom-defined SegmentationDataset class and the UNet model on Lines 5 and 6. Next, we import our config file on Line 7.

Since our salt segmentation task is a pixel-level binary classification problem, we will be using binary cross-entropy loss to train our model. On Line 8, we import the binary cross-entropy loss function (i.e., BCEWithLogitsLoss) from the PyTorch nn module. In addition to this, we import the Adam optimizer from the PyTorch optim module, which we will be using to train our network (Line 9).

Next, on Line 11, we import the in-built train_test_split function from the sklearn library, enabling us to split our dataset into training and testing sets. Furthermore, we import the transforms module from torchvision on Line 12 to apply image transformations on our input images.

Finally, we import other useful packages for handling our file system, keeping track of progress during training, timing our training process, and plotting loss curves on Lines 13-18.

Once we have imported all necessary packages, we will load our data and structure the data loading pipeline.

# load the image and mask filepaths in a sorted manner
imagePaths = sorted(list(paths.list_images(config.IMAGE_DATASET_PATH)))
maskPaths = sorted(list(paths.list_images(config.MASK_DATASET_PATH)))

# partition the data into training and testing splits using 85% of
# the data for training and the remaining 15% for testing
split = train_test_split(imagePaths, maskPaths,
    test_size=config.TEST_SPLIT, random_state=42)

# unpack the data split
(trainImages, testImages) = split[:2]
(trainMasks, testMasks) = split[2:]

# write the testing image paths to disk so that we can use them
# when evaluating/testing our model
print("[INFO] saving testing image paths...")
f = open(config.TEST_PATHS, "w")
f.write("\n".join(testImages))
f.close()

On Lines 21 and 22, we first define two lists (i.e., imagePaths and maskPaths) that store the paths of all images and their corresponding segmentation masks, respectively.

We then partition our dataset into a training and test set with the help of scikit-learn’s train_test_split on Line 26. Note that this function takes as input a sequence of lists (here, imagePaths and maskPaths) and simultaneously returns the training and test set images and corresponding training and test set masks which we unpack on Lines 30 and 31.
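What matters in this step is that each image stays paired with its mask after shuffling. The sketch below illustrates the idea in plain Python by shuffling one list of indices and applying it to both lists. It is not scikit-learn's implementation; the paired_split helper and the synthetic file names are assumptions made for illustration.

```python
import math
import random


def paired_split(images, masks, test_size=0.15, seed=42):
    # shuffle indices once and use the same order for both lists
    indices = list(range(len(images)))
    random.Random(seed).shuffle(indices)
    n_test = math.ceil(test_size * len(images))
    test_idx, train_idx = indices[:n_test], indices[n_test:]

    def pick(seq, idx):
        return [seq[i] for i in idx]

    # same return order as train_test_split(images, masks)
    return (pick(images, train_idx), pick(images, test_idx),
            pick(masks, train_idx), pick(masks, test_idx))


images = [f"images/{i:03d}.png" for i in range(20)]
masks = [f"masks/{i:03d}.png" for i in range(20)]
trainImages, testImages, trainMasks, testMasks = paired_split(images, masks)

# every image is still matched with the mask of the same file name
aligned = all(i.split("/")[-1] == m.split("/")[-1]
              for i, m in zip(trainImages + testImages, trainMasks + testMasks))
print(len(trainImages), len(testImages), aligned)
```

The same alignment check on file names is a cheap sanity test to run on the real trainImages and trainMasks lists before training.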

We then write the paths in the testImages list to the text file defined by config.TEST_PATHS on Lines 36-38, so that we can reuse them when evaluating our model.

Now, we are ready to set up our data loading pipeline.

# define transformations
transforms = transforms.Compose([transforms.ToPILImage(),
    transforms.Resize((config.INPUT_IMAGE_HEIGHT,
        config.INPUT_IMAGE_WIDTH)),
    transforms.ToTensor()])

# create the train and test datasets
trainDS = SegmentationDataset(imagePaths=trainImages, maskPaths=trainMasks,
    transforms=transforms)
testDS = SegmentationDataset(imagePaths=testImages, maskPaths=testMasks,
    transforms=transforms)
print(f"[INFO] found {len(trainDS)} examples in the training set...")
print(f"[INFO] found {len(testDS)} examples in the test set...")

# create the training and test data loaders
trainLoader = DataLoader(trainDS, shuffle=True,
    batch_size=config.BATCH_SIZE, pin_memory=config.PIN_MEMORY,
    num_workers=os.cpu_count())
testLoader = DataLoader(testDS, shuffle=False,
    batch_size=config.BATCH_SIZE, pin_memory=config.PIN_MEMORY,
    num_workers=os.cpu_count())

We first define the transformations that we want to apply while loading our input images and consolidate them with the help of the Compose function on Lines 41-44. Our transformations include:

  • ToPILImage(): converts our input images to the PIL image format. Note that this is necessary since we used OpenCV to load images in our custom dataset, which returns NumPy arrays, while the Resize transform that follows operates on PIL images.
  • Resize(): resizes our images to the input dimension (i.e., config.INPUT_IMAGE_HEIGHT, config.INPUT_IMAGE_WIDTH) that our model expects.
  • ToTensor(): converts the PIL image to a PyTorch tensor and rescales its pixel values from the range [0, 255] to [0, 1].
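The rescaling step matters for the masks in particular: a white mask pixel (255) becomes a target of 1.0 and a black pixel (0) becomes 0.0, which is the range a binary cross-entropy loss expects. The plain-Python sketch below mimics only this scaling for 8-bit values; to_unit_range is an illustrative helper and does not reproduce the rest of what ToTensor does (such as reordering dimensions).

```python
def to_unit_range(pixels):
    # scale 8-bit pixel values from [0, 255] to [0, 1]
    return [[v / 255.0 for v in row] for row in pixels]


# a hypothetical 2x2 mask: white = salt, black = sediment
mask = [[0, 255],
        [255, 0]]
scaled = to_unit_range(mask)
print(scaled)
```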

Finally, we pass the train and test images and corresponding masks to our custom SegmentationDataset to create the training dataset (i.e., trainDS) and test dataset (i.e., testDS) on Lines 47-50. Note that we can simply pass the transforms defined on Line 41 to our custom PyTorch dataset, which then applies them automatically while loading the images.

We can now print the number of samples in trainDS and testDS with the help of the len() method, as shown in Lines 51 and 52.

On Lines 55-60, we create our training dataloader (i.e., trainLoader) and test dataloader (i.e., testLoader) directly by passing our train dataset and test dataset to the PyTorch DataLoader class. We set the shuffle parameter to True in the train dataloader so that each batch is a random sample of the training set and the batch composition changes every epoch, which helps mini-batch gradient-based optimization learn and converge well.

Now that we have structured and defined our data loading pipeline, we will initialize our U-Net model and the training parameters.

# initialize our UNet model
unet = UNet().to(config.DEVICE)
# initialize loss function and optimizer
lossFunc = BCEWithLogitsLoss()
opt = Adam(unet.parameters(), lr=config.INIT_LR)
# calculate steps per epoch for training and test set
trainSteps = len(trainDS) // config.BATCH_SIZE
testSteps = len(testDS) // config.BATCH_SIZE
# initialize a dictionary to store training history
H = {"train_loss": [], "test_loss": []}

We start by initializing our UNet() model on Line 63. Note that the to() function takes as input our config.DEVICE and moves our model and its parameters to that device.

On Lines 66 and 67, we define our loss function and optimizer, which we will use to train our segmentation model. The Adam optimizer class takes as input the parameters of our model (i.e., unet.parameters()) and the learning rate (i.e., config.INIT_LR) we will be using to train our model.

We then define the number of steps required to iterate over our entire train and test set, that is, trainSteps and testSteps, on Lines 70 and 71. Given that the dataloader provides our model config.BATCH_SIZE number of samples to process at a time, the number of steps required to iterate over the entire dataset (i.e., train or test set) can be calculated by dividing the total samples in the dataset by the batch size.
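The step arithmetic can be checked with a few lines of plain Python. The dataset size and batch size below are hypothetical. One caveat worth knowing: with the DataLoader default of drop_last=False, the last partial batch is still yielded, so the loop may run one step more than the floor division suggests, and len(trainLoader) would report that count directly.

```python
import math

numSamples = 3400  # hypothetical dataset size
batchSize = 64     # hypothetical stand-in for config.BATCH_SIZE

# floor division, as used in the tutorial
stepsFloor = numSamples // batchSize
# number of batches yielded when the final partial batch is kept
stepsCeil = math.ceil(numSamples / batchSize)

print(stepsFloor, stepsCeil)  # 53 54
```

When the dataset size is an exact multiple of the batch size, both values agree.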

We also create an empty dictionary, H, on Line 74, that we will use to keep track of our training and test loss history.

Finally, we are in good shape to start understanding our training loop.

# loop over epochs
print("[INFO] training the network...")
startTime = time.time()
for e in tqdm(range(config.NUM_EPOCHS)):
    # set the model in training mode
    unet.train()
    # initialize the total training and validation loss
    totalTrainLoss = 0
    totalTestLoss = 0
    # loop over the training set
    for (i, (x, y)) in enumerate(trainLoader):
        # send the input to the device
        (x, y) = (x.to(config.DEVICE), y.to(config.DEVICE))
        # perform a forward pass and calculate the training loss
        pred = unet(x)
        loss = lossFunc(pred, y)
        # first, zero out any previously accumulated gradients, then
        # perform backpropagation, and then update model parameters
        opt.zero_grad()
        loss.backward()
        opt.step()
        # add the loss to the total training loss so far
        totalTrainLoss += loss
    # switch off autograd
    with torch.no_grad():
        # set the model in evaluation mode
        unet.eval()
        # loop over the validation set
        for (x, y) in testLoader:
            # send the input to the device
            (x, y) = (x.to(config.DEVICE), y.to(config.DEVICE))
            # make the predictions and calculate the validation loss
            pred = unet(x)
            totalTestLoss += lossFunc(pred, y)
    # calculate the average training and validation loss
    avgTrainLoss = totalTrainLoss / trainSteps
    avgTestLoss = totalTestLoss / testSteps
    # update our training history
    H["train_loss"].append(avgTrainLoss.cpu().detach().numpy())
    H["test_loss"].append(avgTestLoss.cpu().detach().numpy())
    # print the model training and validation information
    print("[INFO] EPOCH: {}/{}".format(e + 1, config.NUM_EPOCHS))
    print("Train loss: {:.6f}, Test loss: {:.4f}".format(
        avgTrainLoss, avgTestLoss))
# display the total time needed to perform the training
endTime = time.time()
print("[INFO] total time taken to train the model: {:.2f}s".format(
    endTime - startTime))

To time our training process, we use the time() function on Line 78. This function outputs the time when it is called. Thus, we can call it once at the start and once at the end of our training process and subtract the two outputs to get the time elapsed.

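We iterate for config.NUM_EPOCHS in the training loop, as shown on Line 79. Before we start training, it is important to set our model to train mode, as we see on Line 81. This puts layers whose behavior differs between training and inference (e.g., dropout and batch normalization) into their training behavior. Gradient tracking is a separate mechanism: autograd records operations by default and is only switched off explicitly (e.g., with torch.no_grad()).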

We initialize variables totalTrainLoss and totalTestLoss on Lines 84 and 85 to track our losses in the given epoch. Next, on Line 88, we iterate over our trainLoader dataloader, which provides a batch of samples at a time. The training loop, as shown on Lines 88-103, comprises the following steps:

  • First, on Line 90, we move our data samples (i.e., x and y) to the device we are training our model on, defined by config.DEVICE
  • We then pass our input image sample x through our unet model on Line 93 and get the output prediction (i.e., pred)
  • On Line 94, we compute the loss between the model prediction, pred and our ground-truth label y
  • On Lines 98-100, we backpropagate our loss through the model and update the parameters
    • This is executed with the help of three simple steps; we start by clearing all accumulated gradients from previous steps on Line 98. Next, we call the backward method on our computed loss function as shown on Line 99. This directs PyTorch to compute gradients of our loss w.r.t. all variables involved in the computation graph. Finally, we call opt.step() to update our model parameters as shown on Line 100.
  • In the end, Line 103 enables us to keep track of our training loss by adding the loss for the step to the totalTrainLoss variable, which accumulates the training loss for all samples.
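The steps above can be tried in isolation on a toy problem. The following is a minimal sketch, assuming a single convolution layer as a stand-in for the U-Net and random tensors as a stand-in for one batch. Unlike the tutorial's loop, it accumulates loss.item() (a plain Python float) instead of the loss tensor, which is a common variation and not what train.py does.

```python
import torch
from torch import nn

torch.manual_seed(0)

# hypothetical stand-ins: a one-layer "model" and one random batch
model = nn.Conv2d(3, 1, kernel_size=3, padding=1)
lossFunc = nn.BCEWithLogitsLoss()
opt = torch.optim.Adam(model.parameters(), lr=0.001)
x = torch.rand(4, 3, 8, 8)                   # batch of images
y = (torch.rand(4, 1, 8, 8) > 0.5).float()   # batch of binary masks

before = model.weight.detach().clone()
totalTrainLoss = 0.0

# forward pass and loss computation
pred = model(x)
loss = lossFunc(pred, y)
# clear old gradients, backpropagate, then update the parameters
opt.zero_grad()
loss.backward()
opt.step()
# accumulate the loss for this step as a plain float
totalTrainLoss += loss.item()

print(totalTrainLoss)
```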

This process is repeated until we have iterated through all dataset samples once (i.e., completed one epoch).

Once we have processed our entire training set, we would want to evaluate our model on the test set. This is helpful since it allows us to monitor the test loss and ensure that our model is not overfitting to the training set.

While evaluating our model on the test set, we do not track gradients since we will not be learning or backpropagating. Thus we switch off gradient computation with the help of torch.no_grad(), as shown on Line 106. This directs the PyTorch engine not to build a computational graph or store gradients, saving memory and compute during evaluation. The model weights also stay unchanged here, simply because we never call opt.step() inside this block.

We set our model to evaluation mode by calling the eval() function on Line 108. Then, we iterate through the test set samples and compute the predictions of our model on test data (Line 116). The test loss is then added to the totalTestLoss, which accumulates the test loss for the entire test set.
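The two switches used during evaluation do different jobs, which a toy model can show. This is a small sketch with a hypothetical two-layer model (not the U-Net): eval() changes how layers such as dropout behave, while torch.no_grad() is what stops autograd from recording operations.

```python
import torch
from torch import nn

# hypothetical toy model containing a dropout layer
model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))
x = torch.ones(1, 4)

# eval() makes dropout a no-op, so repeated passes give identical outputs
model.eval()
evalOut1 = model(x)
evalOut2 = model(x)
# autograd is still recording at this point
tracksGradInEval = evalOut1.requires_grad

# no_grad() is what disables gradient tracking
with torch.no_grad():
    noGradOut = model(x)

print(tracksGradInEval, noGradOut.requires_grad)  # True False
```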

We then obtain the average training loss and test loss over all steps, that is, avgTrainLoss and avgTestLoss on Lines 120 and 121, and store them on Lines 124 and 125, to our dictionary, H, which we had created in the beginning to keep track of our losses.

Finally, we print the current epoch statistics, including train and test losses on Lines 128-130. This brings us to the end of one epoch, consisting of one full cycle of training on our train set and evaluation on our test set. This entire process is repeated config.NUM_EPOCHS times until our model converges.

On Lines 133 and 134, we note the end time of our training loop and subtract startTime (which we had initialized at the beginning of training) from endTime to get the total time elapsed during our network training.

# plot the training loss
plt.style.use("ggplot")
plt.figure()
plt.plot(H["train_loss"], label="train_loss")
plt.plot(H["test_loss"], label="test_loss")
plt.title("Training Loss on Dataset")
plt.xlabel("Epoch #")
plt.ylabel("Loss")
plt.legend(loc="lower left")
plt.savefig(config.PLOT_PATH)
# serialize the model to disk
torch.save(unet, config.MODEL_PATH)

Next, we use the pyplot package of matplotlib to visualize and save our training and test loss curves on Lines 138-146. We can do this by simply passing the train_loss and test_loss keys of our loss history dictionary, H, to the plot function as shown on Lines 140 and 141. Finally, we set the title and legends of our plots (Lines 142-145) and save our visualizations on Line 146.

Finally, on Line 149, we save our trained U-Net model with the help of the torch.save() function, which takes our trained unet model and the config.MODEL_PATH where we want it to be saved. Note that passing the unet object itself serializes the entire model (architecture and weights) rather than only its weights.
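A commonly used alternative is to save only the parameters through state_dict() and rebuild the architecture before loading. The sketch below is not what the tutorial's script does; it uses a single convolution layer as a hypothetical stand-in for the trained U-Net and a temporary file in place of config.MODEL_PATH.

```python
import os
import tempfile

import torch
from torch import nn

# hypothetical stand-in for the trained U-Net
model = nn.Conv2d(3, 1, kernel_size=3, padding=1)

# hypothetical path; the tutorial uses config.MODEL_PATH
weightsPath = os.path.join(tempfile.mkdtemp(), "unet_weights.pth")

# save only the parameters rather than the whole model object
torch.save(model.state_dict(), weightsPath)

# to load, build the same architecture first, then fill in the weights
restored = nn.Conv2d(3, 1, kernel_size=3, padding=1)
restored.load_state_dict(torch.load(weightsPath))
```

With this approach the loading code has to construct the model itself, whereas loading a fully serialized model needs the original class definition to be importable.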

Once our model is trained, we will see a loss trajectory plot similar to the one shown in Figure 4. Notice that train_loss gradually reduces over epochs and slowly converges. Furthermore, we see that test_loss also reduces consistently, following a similar trend and similar values to train_loss, which suggests that our model generalizes well and is not overfitting to the training set.

Figure 4: Training and test loss curves of our U-Net model.

Using Our Trained U-Net Model for Prediction

Once we have trained and saved our segmentation model, we are ready to see it in action and use it for segmentation tasks.

Open the predict.py file from our project directory.

# USAGE
# python predict.py
# import the necessary packages
from pyimagesearch import config
import matplotlib.pyplot as plt
import numpy as np
import torch
import cv2
import os

def prepare_plot(origImage, origMask, predMask):
    # initialize our figure
    figure, ax = plt.subplots(nrows=1, ncols=3, figsize=(10, 10))
    # plot the original image, its mask, and the predicted mask
    ax[0].imshow(origImage)
    ax[1].imshow(origMask)
    ax[2].imshow(predMask)
    # set the titles of the subplots
    ax[0].set_title("Image")
    ax[1].set_title("Original Mask")
    ax[2].set_title("Predicted Mask")
    # set the layout of the figure and display it
    figure.tight_layout()
    figure.show()

We import the necessary packages and modules as always on Lines 5-10.

To use our segmentation model for prediction, we will need a function that can take our trained model and test images, predict the output segmentation mask and finally, visualize the output predictions.

To this end, we start by defining the prepare_plot function to help us to visualize our model predictions.

This function takes as input an image, its ground-truth mask, and the segmentation output predicted by our model, that is, origImage, origMask, and predMask (Line 12) and creates a grid with a single row and three columns (Line 14) to display them (Lines 17-19).

Finally, Lines 22-24 set titles for our plots, displaying them on Lines 27 and 28.

def make_predictions(model, imagePath):
    # set model to evaluation mode
    model.eval()
    # turn off gradient tracking
    with torch.no_grad():
        # load the image from disk, swap its color channels, cast it
        # to float data type, and scale its pixel values
        image = cv2.imread(imagePath)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = image.astype("float32") / 255.0
        # resize the image and make a copy of it for visualization
        image = cv2.resize(image, (128, 128))
        orig = image.copy()
        # find the filename and generate the path to ground truth
        # mask
        filename = imagePath.split(os.path.sep)[-1]
        groundTruthPath = os.path.join(config.MASK_DATASET_PATH,
            filename)
        # load the ground-truth segmentation mask in grayscale mode
        # and resize it (cv2.resize expects the size as (width, height))
        gtMask = cv2.imread(groundTruthPath, 0)
        gtMask = cv2.resize(gtMask, (config.INPUT_IMAGE_WIDTH,
            config.INPUT_IMAGE_HEIGHT))

Next, we define our make_predictions function (Lines 31-77), which takes as input our trained segmentation model and the path to a test image and plots the predicted output.

Since we are only using our trained model for prediction, we start by setting our model to eval mode and switching off PyTorch gradient computation on Line 33 and Line 36, respectively.

On Lines 39-41, we load the test image (i.e., image) from imagePath using OpenCV (Line 39), convert it to RGB format (Line 40), and normalize its pixel values from the standard range [0, 255] to the range [0, 1], which our model is trained to process (Line 41).

The image is then resized to the standard image dimension that our model can accept on Line 44. Since we will have to modify and process the image variable before passing it through the model, we make an additional copy of it on Line 45 and store it in the orig variable, which we will use later.

On Lines 49-51, we get the path to the ground-truth mask for our test image and load the mask on Line 55. Note that we resize the mask to the same dimensions as the input image (Lines 56 and 57).
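The path handling relies on each image and its mask sharing the same filename in different folders. A minimal sketch of that lookup follows; the example path and mask directory are hypothetical stand-ins for imagePath and config.MASK_DATASET_PATH, and os.path.basename is used here as an equivalent of the split on os.path.sep.

```python
import os

# hypothetical stand-ins for imagePath and config.MASK_DATASET_PATH
imagePath = os.path.join("dataset", "train", "images", "0a1b2c.png")
maskDir = os.path.join("dataset", "train", "masks")

# same result as imagePath.split(os.path.sep)[-1]
filename = os.path.basename(imagePath)
# the mask is assumed to have the same filename as the image
groundTruthPath = os.path.join(maskDir, filename)

print(filename)         # 0a1b2c.png
print(groundTruthPath)
```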

Now we process our image to a format that our model can process. Note that currently, our image has the shape [128, 128, 3]. However, our segmentation model accepts four-dimensional inputs of the format [batch_dimension, channel_dimension, height, width].

        # make the channel axis to be the leading one, add a batch
        # dimension, create a PyTorch tensor, and flash it to the
        # current device
        image = np.transpose(image, (2, 0, 1))
        image = np.expand_dims(image, 0)
        image = torch.from_numpy(image).to(config.DEVICE)
        # make the prediction, pass the results through the sigmoid
        # function, and convert the result to a NumPy array
        predMask = model(image).squeeze()
        predMask = torch.sigmoid(predMask)
        predMask = predMask.cpu().numpy()
        # filter out the weak predictions and convert them to integers
        predMask = (predMask > config.THRESHOLD) * 255
        predMask = predMask.astype(np.uint8)
        # prepare a plot for visualization
        prepare_plot(orig, gtMask, predMask)

On Line 62, we transpose the image to convert it to channel-first format, that is, [3, 128, 128], and on Line 63, we add an extra dimension using the expand_dims function of numpy to convert our image into a four-dimensional array (i.e., [1, 3, 128, 128]). Note that the first dimension here represents the batch dimension equal to one since we are processing one test image at a time. We then convert our image to a PyTorch tensor with the help of the torch.from_numpy() function and move it to the device our model is on with the help of Line 64.
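The shape changes described above are easy to verify on a dummy array. This sketch uses an all-zeros image as a hypothetical placeholder for a loaded test image.

```python
import numpy as np

# hypothetical image in (height, width, channels) layout
image = np.zeros((128, 128, 3), dtype="float32")

# channels first: (H, W, C) -> (C, H, W)
chw = np.transpose(image, (2, 0, 1))
# add a leading batch axis of size one
batch = np.expand_dims(chw, 0)

print(image.shape, chw.shape, batch.shape)
# (128, 128, 3) (3, 128, 128) (1, 3, 128, 128)
```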

Finally, on Lines 68-70, we process our test image by passing it through our model and saving the output prediction as predMask. We then apply the sigmoid activation to get our predictions in the range [0, 1]. As discussed earlier, the segmentation task is a classification problem where we have to classify each pixel into one of two discrete classes. Since sigmoid outputs continuous values in the range [0, 1], we use our config.THRESHOLD on Line 73 to binarize our output and assign each pixel a value of 0 or 1. This implies that anything greater than the threshold will be assigned the value 1, and everything else will be assigned 0.

Since the thresholded output (i.e., (predMask > config.THRESHOLD)) now comprises values 0 or 1, multiplying it by 255 makes the final pixel values in our predMask either 0 (i.e., pixel value for black color) or 255 (i.e., pixel value for white color). As discussed earlier, the white pixels will correspond to the region where our model has detected salt deposits, and the black pixels correspond to regions where salt is not present.
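The sigmoid, threshold, and scaling steps can be reproduced on a handful of numbers with NumPy. The logits and the threshold of 0.5 below are hypothetical values chosen for illustration; the tutorial reads the threshold from config.THRESHOLD.

```python
import numpy as np

threshold = 0.5  # hypothetical stand-in for config.THRESHOLD

# hypothetical raw model outputs (logits) for a 2x2 mask
logits = np.array([[-2.0, 0.0], [0.5, 3.0]], dtype=np.float32)

# sigmoid squashes the logits into the range (0, 1)
probs = 1.0 / (1.0 + np.exp(-logits))
# the comparison yields booleans; multiplying by 255 maps them to 0 or 255
predMask = ((probs > threshold) * 255).astype(np.uint8)

print(predMask)
# [[  0   0]
#  [255 255]]
```

Note that a logit of exactly 0 gives a probability of exactly 0.5, which is not strictly greater than the threshold and therefore maps to 0.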

We plot our original image (i.e., orig), ground-truth mask (i.e., gtMask), and our predicted output (i.e., predMask) with the help of our prepare_plot function on Line 77. This completes the definition of our make_predictions function.

We are ready to see our model in action now.

# load the image paths in our testing file and randomly select 10
# image paths
print("[INFO] loading up test image paths...")
imagePaths = open(config.TEST_PATHS).read().strip().split("\n")
imagePaths = np.random.choice(imagePaths, size=10)
# load our model from disk and flash it to the current device
print("[INFO] load up model...")
unet = torch.load(config.MODEL_PATH).to(config.DEVICE)
# iterate over the randomly selected test image paths
for path in imagePaths:
    # make predictions and visualize the results
    make_predictions(unet, path)

On Lines 82 and 83, we open the file where our test image paths are stored and randomly grab 10 image paths. Line 87 loads our trained U-Net model from the checkpoint saved at config.MODEL_PATH and moves it to config.DEVICE.
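One detail of the sampling step: np.random.choice draws with replacement by default, so the same path can appear more than once among the 10. If distinct images are preferred, passing replace=False is one option, sketched below on a hypothetical list of paths (this is a variation, not the tutorial's code).

```python
import numpy as np

# hypothetical list of 50 test image paths
imagePaths = [f"images/{i:03d}.png" for i in range(50)]

np.random.seed(42)
# replace=False guarantees that the ten sampled paths are distinct
sample = np.random.choice(imagePaths, size=10, replace=False)

print(len(sample), len(set(sample.tolist())))  # 10 10
```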

We finally iterate over our randomly chosen test imagePaths and predict the outputs with the help of our make_predictions function on Lines 90-92.

Figure 5 shows sample visualization outputs from our make_predictions function. The yellow region represents Class 1: Salt and the dark blue region represents Class 2: Not Salt (sediment).

Figure 5: Sample outputs from make_predictions, showing the input image, original mask, and predicted mask for each case.

We see that in case 1 and case 2 (i.e., row 1 and row 2, respectively), our model correctly identified most of the locations containing salt deposits. However, some regions where the salt deposit exists are not identified.

However, in case 3 (i.e., row 3), our model has identified some regions as salt deposits where there is no salt (the yellow blob in the middle). This is a false positive, where our model has incorrectly predicted the positive class, that is, the presence of salt, in a region where it does not exist in the ground truth.

It is worth noting that, from an application point of view, the prediction in case 3 is more misleading and riskier than those in the other two cases. In the first two cases, if experts set up drills to mine salt at the locations predicted in yellow, they will find salt deposits. However, if they do the same at the location of a false-positive prediction (as seen in case 3), they will waste time and resources, since no salt deposit exists there.
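The two kinds of error discussed above (missed salt and falsely predicted salt) can be counted per pixel by comparing the two masks. This is a small illustrative sketch on made-up 3x3 masks, assuming the same 0/255 encoding as predMask; it is not part of the tutorial's predict.py.

```python
import numpy as np

# hypothetical 3x3 ground-truth and predicted masks (255 = salt, 0 = no salt)
gtMask = np.array([[255, 255, 0], [0, 0, 0], [0, 0, 0]], dtype=np.uint8)
predMask = np.array([[255, 0, 0], [0, 255, 0], [0, 0, 0]], dtype=np.uint8)

# convert both masks to boolean "salt present" maps
gt = gtMask > 0
pred = predMask > 0

# false positives: salt predicted where the ground truth has none
falsePositives = int(np.sum(pred & ~gt))
# false negatives: salt in the ground truth that the model missed
falseNegatives = int(np.sum(~pred & gt))

print(falsePositives, falseNegatives)  # 1 1
```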

Credits

Our implementation of the U-Net model in the model.py file is inspired by Aman Arora’s amazing article.

Summary

In this tutorial, we learned about image segmentation and built a U-Net-based image segmentation pipeline from scratch in PyTorch.

Specifically, we discussed the architectural details and salient features of the U-Net model that make it the de-facto choice for image segmentation.

In addition, we learned how we can define our own custom dataset in PyTorch for the segmentation task at hand.

Finally, we saw how we can train our U-Net-based segmentation pipeline in PyTorch and use the trained model to make predictions on test images in real time.

After following the tutorial, you will be able to understand the internal working of any image segmentation pipeline and build your own segmentation models from scratch in PyTorch.

Author: Kieth Sipes
