This practical demonstrates how to build, train, and evaluate a semantic segmentation model. We implement a simplified version of the U-Net architecture, a popular choice for segmentation tasks (particularly in domains like medical imaging), using PyTorch. While U-Net is the primary focus, the principles apply to many other architectures such as FCNs or DeepLab.

We assume you have a working Python environment with PyTorch, TorchVision, and libraries like NumPy and Matplotlib installed.

## 1. Preparing the Data

Semantic segmentation requires images and corresponding pixel-level masks. Each pixel in the mask is labeled with the class it belongs to (e.g., 0 for background, 1 for road, 2 for building).

For this exercise, you might use a standard dataset like Pascal VOC or Cityscapes, or even create a simple synthetic dataset. Let's assume we have a dataset directory structure like this:

```
data/
├── images/
│   ├── 0001.png
│   ├── 0002.png
│   └── ...
└── masks/
    ├── 0001.png
    ├── 0002.png
    └── ...
```

We'll need a custom PyTorch `Dataset` class to load images and masks.

```python
import os

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image


class SegmentationDataset(Dataset):
    def __init__(self, image_dir, mask_dir, transform=None, mask_transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.image_filenames = sorted(os.listdir(image_dir))
        self.mask_filenames = sorted(os.listdir(mask_dir))
        self.transform = transform
        self.mask_transform = mask_transform

        # Basic check: ensure image and mask lists match
        assert len(self.image_filenames) == len(self.mask_filenames), \
            "Number of images and masks must be the same."
        # Optionally add more rigorous checks here (e.g., matching filenames)

    def __len__(self):
        return len(self.image_filenames)

    def __getitem__(self, idx):
        img_path = os.path.join(self.image_dir, self.image_filenames[idx])
        mask_path = os.path.join(self.mask_dir, self.mask_filenames[idx])

        image = Image.open(img_path).convert("RGB")
        mask = Image.open(mask_path).convert("L")  # Assuming mask is grayscale

        if self.transform:
            image = self.transform(image)

        if self.mask_transform:
            # Important: apply geometric transforms identically to image and mask,
            # but do not normalize the mask values like the image.
            # Random transformations need careful handling (see the paired-flip sketch below).
            # For simplicity here, assume basic resize/tensor conversion.
            mask = self.mask_transform(mask)
            # Convert mask to LongTensor for CrossEntropyLoss
            mask = mask.squeeze(0).long()
        else:
            # Default conversion if no specific mask transform
            mask = torch.from_numpy(np.array(mask)).long()

        return image, mask


# Define transformations (adjust size and normalization as needed)
image_transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

mask_transform = transforms.Compose([
    transforms.Resize((128, 128), interpolation=transforms.InterpolationMode.NEAREST),  # Use NEAREST for masks
    transforms.PILToTensor()  # Keeps integer class indices; ToTensor() would rescale them to [0, 1]
])

# Create Datasets and DataLoaders
# Replace with your actual data paths
train_dataset = SegmentationDataset('data/images', 'data/masks',
                                    transform=image_transform,
                                    mask_transform=mask_transform)
# val_dataset = SegmentationDataset('data_val/images', 'data_val/masks',
#                                   transform=image_transform,
#                                   mask_transform=mask_transform)  # For validation

train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=4)
# val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False, num_workers=4)
```

Note the use of `transforms.InterpolationMode.NEAREST` for resizing masks: it prevents interpolation from creating invalid class labels between existing ones. Similarly, `transforms.PILToTensor()` is used for the mask instead of `transforms.ToTensor()`, so the integer class indices are preserved rather than rescaled to [0, 1]. Mask tensors should be of type `LongTensor`.
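The comment inside `__getitem__` about applying geometric transforms identically to image and mask is worth illustrating. One common pattern is to draw the random parameters once and apply them to both via `torchvision.transforms.functional`. The helper below is a minimal sketch under that assumption (the name `paired_random_flip` is ours, not a torchvision API); you could call it on the PIL images inside `__getitem__`, before the resize and tensor conversion.

```python
import random

import torchvision.transforms.functional as TF


def paired_random_flip(image, mask, p=0.5):
    """Apply the same random horizontal flip to an image and its mask (hypothetical helper)."""
    if random.random() < p:
        image = TF.hflip(image)
        mask = TF.hflip(mask)
    return image, mask
```

The same idea extends to random crops or rotations: sample the parameters once (e.g., with `transforms.RandomCrop.get_params`) and pass them to the functional calls for both the image and the mask.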
## 2. Defining the U-Net Model

Let's implement a simplified U-Net. It consists of an encoder (contracting path) that captures context and a decoder (expansive path) that enables precise localization through upsampling and skip connections.

```python
import torch.nn as nn
import torch.nn.functional as F


class DoubleConv(nn.Module):
    """(Convolution => BatchNorm => ReLU) * 2"""

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)


class Down(nn.Module):
    """Downscaling with maxpool then double conv"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.maxpool_conv(x)


class Up(nn.Module):
    """Upscaling then double conv"""

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        # If bilinear, use plain upsampling and let the convolutions reduce the channel count
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # Pad x1 so its spatial size matches the skip connection x2
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]
        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        # If you have padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e5474a7ae105f32e70a5168b
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)


class UNet(nn.Module):
    def __init__(self, n_channels, n_classes, bilinear=True):
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        factor = 2 if bilinear else 1
        self.down4 = Down(512, 1024 // factor)
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits


# Instantiate the model
# n_channels=3 for RGB images, n_classes = number of segmentation classes (e.g., 2 for binary)
num_classes = 2  # Example: background + foreground
model = UNet(n_channels=3, n_classes=num_classes)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
```

This U-Net implementation uses standard convolutional blocks, max-pooling for downsampling, and either bilinear upsampling or transposed convolutions for upsampling. The skip connections concatenate feature maps from the encoder path with the upsampled feature maps in the decoder path, helping to recover fine-grained details lost during downsampling.
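As a quick sanity check, assuming the 128×128 inputs produced by the transforms above, you can pass a dummy batch through the model and confirm that the output has one channel per class at the input resolution:

```python
# Dummy forward pass: the output should have shape [batch, num_classes, H, W]
with torch.no_grad():
    dummy = torch.randn(1, 3, 128, 128, device=device)
    print(model(dummy).shape)  # expected: torch.Size([1, 2, 128, 128])
```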
## 3. Loss Function and Optimizer

For semantic segmentation with multiple classes, the standard loss function is cross-entropy loss applied pixel-wise: each pixel is treated as its own classification problem. If your dataset is highly imbalanced (e.g., small objects in large backgrounds), you might consider weighted cross-entropy or Dice Loss (a sketch follows below).

$$ \text{CrossEntropyLoss}(output, target) = -\sum_{c=1}^{C} target_c \log(\text{softmax}(output)_c) $$

where $C$ is the number of classes, $output$ are the raw logits from the model for a pixel, and $target$ is the one-hot encoded ground-truth label for that pixel (PyTorch's `nn.CrossEntropyLoss` handles integer targets directly).

We'll use the Adam optimizer.

```python
import torch.optim as optim

# Loss function
# `ignore_index` can be useful if you have a label to ignore (e.g., border pixels):
# criterion = nn.CrossEntropyLoss(ignore_index=255)
criterion = nn.CrossEntropyLoss()

# Optimizer
optimizer = optim.Adam(model.parameters(), lr=1e-4)  # Learning rate may need tuning
```
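If you do want to try Dice Loss, a common approach is a soft multi-class formulation computed from the softmax probabilities. The sketch below is one minimal variant under that assumption (the class name `SoftDiceLoss` is ours, not part of PyTorch); it is not used elsewhere in this practical.

```python
class SoftDiceLoss(nn.Module):
    """Minimal multi-class soft Dice loss computed from raw logits (illustrative sketch)."""

    def __init__(self, smooth=1e-6):
        super().__init__()
        self.smooth = smooth

    def forward(self, logits, targets):
        # logits: [B, C, H, W]; targets: [B, H, W] with integer class indices
        num_classes = logits.shape[1]
        probs = torch.softmax(logits, dim=1)
        targets_onehot = F.one_hot(targets, num_classes).permute(0, 3, 1, 2).float()
        dims = (0, 2, 3)  # sum over batch and spatial dimensions, keeping classes separate
        intersection = (probs * targets_onehot).sum(dims)
        cardinality = probs.sum(dims) + targets_onehot.sum(dims)
        dice_per_class = (2.0 * intersection + self.smooth) / (cardinality + self.smooth)
        return 1.0 - dice_per_class.mean()

# Example (could replace, or be combined with, cross-entropy):
# criterion = SoftDiceLoss()
```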
## 4. Training Loop

The training loop iterates through the dataset, performs forward and backward passes, and updates the model weights.

```python
num_epochs = 25  # Adjust as needed
train_losses = []

model.train()  # Set model to training mode

for epoch in range(num_epochs):
    running_loss = 0.0

    for i, (images, masks) in enumerate(train_loader):
        images = images.to(device)
        masks = masks.to(device)  # Shape: [batch_size, H, W]

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = model(images)  # Shape: [batch_size, num_classes, H, W]

        # Calculate loss
        loss = criterion(outputs, masks)

        # Backward pass and optimize
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        if (i + 1) % 50 == 0:  # Print status every 50 mini-batches
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], '
                  f'Loss: {loss.item():.4f}')

    epoch_loss = running_loss / len(train_loader)
    train_losses.append(epoch_loss)
    print(f'Epoch [{epoch+1}/{num_epochs}] completed. Average Loss: {epoch_loss:.4f}')

print('Finished Training')

# Save the trained model (optional)
# torch.save(model.state_dict(), 'unet_segmentation_model.pth')
```

This is a basic training loop. In practice, you would add:

- A validation loop to monitor performance on unseen data (see the sketch after this list).
- Learning rate scheduling.
- Calculation of evaluation metrics like IoU within the validation loop.
- Saving model checkpoints.
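As a minimal illustration of the first point, assuming a `val_loader` built like the commented-out one in Section 1, a validation pass computes the loss without gradient tracking (the helper name `validate` is ours):

```python
def validate(model, val_loader, criterion, device):
    """Average validation loss over one pass through val_loader (illustrative sketch)."""
    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for images, masks in val_loader:
            images, masks = images.to(device), masks.to(device)
            total_loss += criterion(model(images), masks).item()
    model.train()  # Restore training mode for the next epoch
    return total_loss / len(val_loader)

# Example: call at the end of each epoch
# val_loss = validate(model, val_loader, criterion, device)
```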
## 5. Evaluation Metric: Intersection over Union (IoU)

The most common metric for segmentation is Intersection over Union (IoU), also known as the Jaccard index. It measures the overlap between the predicted segmentation mask ($A$) and the ground-truth mask ($B$) for a specific class.

$$ IoU = J(A, B) = \frac{|A \cap B|}{|A \cup B|} = \frac{\text{Intersection Area}}{\text{Union Area}} $$

Mean IoU (mIoU) is often reported, which is the average IoU over all classes.

```python
def calculate_iou(pred, target, num_classes, smooth=1e-6):
    """Calculates IoU for each class on a batch of predictions."""
    pred = torch.argmax(pred, dim=1)  # Convert logits to predicted class indices [B, H, W]
    pred = pred.contiguous().view(-1)
    target = target.contiguous().view(-1)

    iou_per_class = []
    for class_idx in range(num_classes):  # Calculate IoU for each class
        pred_inds = (pred == class_idx)
        target_inds = (target == class_idx)

        intersection = (pred_inds[target_inds]).long().sum().item()
        union = pred_inds.long().sum().item() + target_inds.long().sum().item() - intersection

        if union == 0:
            # Class absent from both prediction and ground truth: undefined for this batch
            iou_per_class.append(float('nan'))  # or 0 or 1, depending on convention
        else:
            iou = (intersection + smooth) / (union + smooth)
            iou_per_class.append(iou)
    return np.array(iou_per_class)


def calculate_miou(pred_loader, model, num_classes, device):
    """Calculates mean IoU over a dataset by averaging per-batch IoUs."""
    model.eval()  # Set model to evaluation mode
    total_iou = np.zeros(num_classes)
    class_counts = np.zeros(num_classes)  # Number of batches in which each class had a valid IoU

    with torch.no_grad():
        for images, masks in pred_loader:
            images = images.to(device)
            outputs = model(images)  # Model predictions (logits)

            iou = calculate_iou(outputs.cpu(), masks, num_classes)  # masks stay on CPU

            # Accumulate per-class IoU, skipping classes absent from this batch (NaN)
            valid = ~np.isnan(iou)
            total_iou[valid] += iou[valid]
            class_counts[valid] += 1

    # Average each class over the batches in which it appeared
    mean_iou_per_class = total_iou / np.maximum(class_counts, 1)
    mean_iou = mean_iou_per_class[class_counts > 0].mean()

    print(f'Mean IoU: {mean_iou:.4f}')
    print(f'IoU per class: {mean_iou_per_class}')
    return mean_iou

# Example usage after training (assuming you have a val_loader)
# mIoU = calculate_miou(val_loader, model, num_classes, device)
```

Note that this simplified version averages per-batch IoUs, which can be less accurate. Implementing a robust mIoU calculation usually involves accumulating the intersection and union counts per class across all batches before dividing, especially when some classes are absent from individual batches (see the sketch below).
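A minimal sketch of that accumulation-based approach could look like the following (the function name `calculate_miou_accumulated` is ours):

```python
def calculate_miou_accumulated(pred_loader, model, num_classes, device):
    """mIoU from intersection/union counts accumulated over the whole dataset (illustrative sketch)."""
    model.eval()
    intersection = np.zeros(num_classes, dtype=np.int64)
    union = np.zeros(num_classes, dtype=np.int64)

    with torch.no_grad():
        for images, masks in pred_loader:
            preds = torch.argmax(model(images.to(device)), dim=1).cpu().view(-1)
            masks = masks.view(-1)
            for c in range(num_classes):
                pred_c = (preds == c)
                mask_c = (masks == c)
                inter = (pred_c & mask_c).sum().item()
                intersection[c] += inter
                union[c] += pred_c.sum().item() + mask_c.sum().item() - inter

    present = union > 0  # Ignore classes that never appear in predictions or ground truth
    iou_per_class = intersection[present] / union[present]
    return iou_per_class.mean()
```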
## 6. Visualizing Predictions

Visualizing the model's output helps you understand its performance qualitatively.

```python
import matplotlib.pyplot as plt


def visualize_predictions(model, device, num_samples=5):
    model.eval()
    samples_shown = 0

    fig, axes = plt.subplots(num_samples, 3, figsize=(10, num_samples * 3))
    axes = np.atleast_2d(axes)  # Keep 2-D indexing consistent when num_samples == 1
    fig.suptitle("Image / Ground Truth / Prediction")

    # Use a separate dataset instance to get resized but un-normalized images and masks for display
    vis_dataset = SegmentationDataset(
        'data/images', 'data/masks',
        transform=transforms.Compose([transforms.Resize((128, 128))]),
        mask_transform=transforms.Compose([
            transforms.Resize((128, 128), interpolation=transforms.InterpolationMode.NEAREST),
            transforms.PILToTensor()
        ]))

    # Normalization is applied only to the model input
    input_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    with torch.no_grad():
        for i in range(len(vis_dataset)):
            if samples_shown >= num_samples:
                break

            raw_image, raw_mask = vis_dataset[i]  # PIL image and [H, W] mask tensor
            input_image = input_transform(raw_image).unsqueeze(0).to(device)  # Prepare for model

            output = model(input_image)
            pred_mask = torch.argmax(output.squeeze(0), dim=0).cpu().numpy()

            axes[samples_shown, 0].imshow(raw_image)
            axes[samples_shown, 0].set_title("Image")
            axes[samples_shown, 0].axis('off')

            axes[samples_shown, 1].imshow(raw_mask.numpy(), cmap='gray')  # Adjust cmap if needed
            axes[samples_shown, 1].set_title("Ground Truth")
            axes[samples_shown, 1].axis('off')

            axes[samples_shown, 2].imshow(pred_mask, cmap='gray')  # Adjust cmap based on num_classes
            axes[samples_shown, 2].set_title("Prediction")
            axes[samples_shown, 2].axis('off')

            samples_shown += 1

    plt.tight_layout(rect=[0, 0.03, 1, 0.95])  # Adjust layout to prevent title overlap
    plt.show()

# Example usage:
# visualize_predictions(model, device)
```

This visualization function shows the original image, the ground-truth mask, and the model's predicted mask side by side for a few examples.

Optionally, you can plot the `train_losses` list collected during training to check convergence.

*Training loss per epoch: the average loss decreases steadily over the 25 epochs (from roughly 1.5 down to about 0.14).*

## Next Steps and Experimentation

This practical provides a starting point. To improve your segmentation model, consider:

- Data Augmentation: Apply spatial and color augmentations (e.g., rotations, flips, brightness changes) consistently to the images and masks (see the paired-flip sketch in Section 1).
- More Complex Architectures: Implement or use pre-built versions of DeepLabV3+ or other advanced models.
- Transfer Learning: Initialize the encoder part of your U-Net with weights pre-trained on a large dataset like ImageNet.
- Different Loss Functions: Experiment with Dice Loss or Focal Loss, especially for imbalanced datasets.
- Hyperparameter Tuning: Systematically tune the learning rate, batch size, optimizer, and network depth.
- Post-processing: Apply techniques like Conditional Random Fields (CRFs) to refine the predicted segmentation boundaries.

Building effective segmentation models involves careful data preparation, appropriate architecture selection, correct loss implementation, and thorough evaluation. This hands-on exercise provides the fundamental building blocks for tackling diverse segmentation challenges.