Introduction

  • At times we have all felt that learning one task makes it easier to learn a similar one. For instance, I played badminton all through my teenage years, and when I later tried tennis it was quite easy to pick up the nuances of the game. Deep learning has a direct analogue: transfer learning, i.e., reusing a model pre-trained on a similar dataset or problem for a custom dataset that is usually small.

  • In the previous blog, we trained AlexNet from scratch and did not observe particularly good results on the custom dataset.

  • Here, we instead use the pre-trained AlexNet provided by PyTorch and observe how the training time and accuracy change.

Libraries

import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.models as models

import sklearn
from sklearn.model_selection import train_test_split

import numpy as np
import pandas as pd
import random
import os
import matplotlib.pyplot as plt
import PIL
from PIL import Image

import time
import seaborn as sns
import glob
from pathlib import Path
torch.manual_seed(1)
np.random.seed(1)
data_path = Path.cwd()/'Fish_Dataset/Fish_Dataset'

# Path for all the files in a 'png' format.
image_path = list(data_path.glob('**/*.png')) 

# Separate Segmented from Non-Segmented Images

non_segmented_images = [img for img in image_path if 'GT' not in str(img)]
labels_non_segment = [img.parts[-3] for img in non_segmented_images]

segmented_images = [img for img in image_path if 'GT' in str(img)]
labels_segment = [img.parts[-3] for img in segmented_images]

classes = list(set(labels_segment))

# Convert String Labels to int

int_classes = {fish:i for i,fish in enumerate(classes)}

labels = [int_classes[label] for label in labels_non_segment]
image_data = pd.DataFrame({'Path': non_segmented_images,
                           'labels': labels})
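  • An optional sanity check on the frame built above: inspect the string-to-int mapping and the per-class image counts.
# Inspect the label mapping and the class balance of the non-segmented images
print(int_classes)
print(image_data['labels'].value_counts())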
class FishDataset(Dataset):
  """Class for loading an Image."""
  def __init__(self, images, labels, transform = None):
    self.images = images
    self.labels = labels
    self.transform = transform

  def __len__(self):
    return len(self.labels)

  def __getitem__(self, idx):
    img = Image.open(self.images.iloc[idx])

    if self.transform:
      img = self.transform(img)
    label = self.labels.iloc[idx]
    return img, label
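  • A minimal usage sketch for the class above (not part of the training pipeline): wrap the paths and labels, pull out one sample, and check its shape.
# Quick check of the Dataset class on a single sample
sample_ds = FishDataset(images=image_data.Path, labels=image_data.labels,
                        transform=transforms.ToTensor())
img, label = sample_ds[0]
print(img.shape, label)  # -> torch.Size([3, H, W]) and an integer class id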
train,test, train_labels, test_labels = train_test_split(image_data.Path, image_data.labels, test_size=0.2, shuffle=True)

train,val, train_labels, val_labels = train_test_split(train, train_labels, test_size=0.2, shuffle=True)
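  • The splits above are purely random; if you want each split to preserve the class proportions, train_test_split also accepts a stratify argument. A possible variant (running it replaces the splits from the cell above):
# Stratified variant of the same splits: keeps class ratios identical across splits
train, test, train_labels, test_labels = train_test_split(
    image_data.Path, image_data.labels, test_size=0.2,
    shuffle=True, stratify=image_data.labels)
train, val, train_labels, val_labels = train_test_split(
    train, train_labels, test_size=0.2, shuffle=True, stratify=train_labels)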
def get_loaders(train, train_labels, val, val_labels,test, test_labels, batch_size, num_workers, train_transform, test_transform):
  """
  Returns the Train, Validation and Test DataLoaders.
  """

  train_ds = FishDataset(images = train, labels = train_labels, transform = train_transform)
  val_ds = FishDataset(images = val, labels = val_labels, transform = test_transform)
  test_ds = FishDataset(images = test, labels = test_labels, transform = test_transform)

  train_loader = DataLoader(train_ds, batch_size=batch_size,num_workers=num_workers,
                            shuffle= True)
  val_loader = DataLoader(val_ds, batch_size=batch_size,num_workers=num_workers,
                            shuffle= False)
  test_loader = DataLoader(test_ds, batch_size=batch_size,num_workers=num_workers,
                          shuffle= False)
  return train_loader, val_loader, test_loader

def set_all_seeds(seed):
  os.environ["PL_GLOBAL_SEED"] = str(seed)
  random.seed(seed)
  np.random.seed(seed)
  torch.manual_seed(seed)
  torch.cuda.manual_seed_all(seed)

def compute_accuracy(model, data_loader, device):
  model.eval()
  with torch.no_grad():
    correct_pred, num_examples = 0, 0

    for i, (features, targets) in enumerate(data_loader):
      features = features.to(device)
      targets = targets.to(device)

      logits = model(features)
      _, predicted_labels = torch.max(logits, 1)

      num_examples += targets.size(0)
      correct_pred += (predicted_labels == targets).sum()
  return correct_pred.float()/num_examples * 100


class UnNormalize(object):
  def __init__(self, mean, std):
    self.mean = mean
    self.std = std
  def __call__(self, tensor):
    """
    Parameters:
    ------------
    tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
        
    Returns:
    ------------
    Tensor: Normalized image.
    """
    for t, m, s in zip(tensor, self.mean, self.std):
      t.mul_(s).add_(m)
    return tensor
### FISH DATASET
##########################

train_transform = transforms.Compose([transforms.Resize((64,64)),
                                      transforms.ColorJitter(brightness=0.5, contrast=0,saturation=0, hue=0),
                                      transforms.ToTensor(),
                                      transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
                                      ])

test_transforms = transforms.Compose([transforms.Resize((64,64)),
                                      transforms.ToTensor(),
                                      transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])
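  • UnNormalize (defined earlier) is never used in the training script, but it is handy for visualizing samples after the Normalize step. A small sketch, assuming the objects defined above:
# Undo the ImageNet normalization so a transformed sample can be displayed
unorm = UnNormalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
sample_img, sample_label = FishDataset(images=image_data.Path,
                                       labels=image_data.labels,
                                       transform=test_transforms)[0]
plt.imshow(unorm(sample_img).permute(1, 2, 0))  # CHW -> HWC for matplotlib
plt.title(classes[sample_label])
plt.show()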

Transfer Learning

Step-1: PreTrain a Neural Network

  • We use an AlexNet pre-trained on the ImageNet dataset as the source model.
alexnet = models.alexnet(pretrained=True)
alexnet
Downloading: "https://download.pytorch.org/models/alexnet-owt-7be5be79.pth" to /root/.cache/torch/hub/checkpoints/alexnet-owt-7be5be79.pth

AlexNet(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))
  (classifier): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): Linear(in_features=9216, out_features=4096, bias=True)
    (2): ReLU(inplace=True)
    (3): Dropout(p=0.5, inplace=False)
    (4): Linear(in_features=4096, out_features=4096, bias=True)
    (5): ReLU(inplace=True)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)
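
  • Note: on newer torchvision releases (0.13 and later), the pretrained flag is deprecated in favor of a weights argument; the equivalent call would be:
# Same model, using the newer weights API instead of pretrained=True
alexnet = models.alexnet(weights=models.AlexNet_Weights.IMAGENET1K_V1)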

Step-2: Create a Target Model

  • We copy all the model parameters from the source model and freeze them; the classifier layers will be unfrozen and fine-tuned in the next step.
for param in alexnet.parameters():
  param.requires_grad = False
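  • At this point no parameter should receive gradient updates; a quick optional check:
# After freezing, the number of trainable parameters should be 0
num_trainable = sum(p.numel() for p in alexnet.parameters() if p.requires_grad)
print(f"Trainable parameters: {num_trainable}")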

Step-3: Fine-Tuning the Target Model

  1. We unfreeze the classifier layers so they can be fine-tuned.
  2. We replace the final layer of the classifier's multilayer perceptron with a new output layer sized for our 9 fish classes.
# Unfreeze the classifier layers (requires_grad_ sets requires_grad on their parameters)
alexnet.classifier[1].requires_grad_(True)
alexnet.classifier[4].requires_grad_(True)

# Replace the output layer with a new 9-class head (trainable by default)
alexnet.classifier[6] = nn.Linear(in_features=4096, out_features=9, bias=True)

Step-4: Train target model on Target Dataset

  1. We train the new output layer from scratch and fine-tune the unfrozen classifier layers on the target dataset; a small sketch of an equivalent optimizer setup follows this list.
  2. The frozen feature-extraction layers keep the parameters learned on ImageNet and are not updated.
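  • The training script below passes model.parameters() to the optimizer, which works because frozen parameters never receive gradients; an equivalent, slightly more explicit variant (not used below) is to hand the optimizer only the trainable parameters:
# Equivalent alternative: optimize only the parameters that still require gradients
trainable_params = [p for p in alexnet.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(trainable_params, momentum=0.9, lr=0.1)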
### SETTINGS
##########################

RANDOM_SEED = 123
BATCH_SIZE = 256
NUM_EPOCHS = 10
WORKERS = 2
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
set_all_seeds(RANDOM_SEED)

train_loader, val_loader, test_loader = get_loaders(train,train_labels,val, val_labels, test,test_labels, BATCH_SIZE,WORKERS,
                                                    train_transform, test_transforms)

model = alexnet.to(DEVICE)
optimizer = torch.optim.SGD(model.parameters(), momentum=0.9, lr=0.1)

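# Note: mode='max' assumes the monitored metric should increase (e.g. validation
# accuracy); when stepping on a loss, mode='min' would be the consistent choice.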
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                       factor=0.1,
                                                       mode='max',
                                                       verbose=True)
logging_interval = 50
scheduler_on='minibatch_loss'
start_time = time.time()

minibatch_loss_list, train_acc_list, valid_acc_list = [],[],[]

for epoch in range(NUM_EPOCHS):
  # Start Training
  model.train()
  for batch_idx, (features, target) in enumerate(train_loader):
    features = features.to(DEVICE)
    targets = target.to(DEVICE)
    # Forward and BackPropagation
    logits = model(features)
    loss = F.cross_entropy(logits, targets)
    optimizer.zero_grad()
    loss.backward()

    # Update Model Parameters
    optimizer.step()

    ## LOGGING
    minibatch_loss_list.append(loss.item())
    if not batch_idx % logging_interval:
      print(f"Epoch = {epoch+1:03d}/{NUM_EPOCHS:03d}"
      f"| Batch {batch_idx:04d}/{len(train_loader):04d}"
      f"| Loss: {loss:.4f}")
    
  ## Validation
  model.eval()
  with torch.no_grad():
    train_acc = compute_accuracy(model, train_loader, DEVICE)
    valid_acc = compute_accuracy(model, val_loader, DEVICE)
    print(f'Epoch: {epoch+1}/{NUM_EPOCHS:03d} '
    f'| Train: {train_acc :.2f}% '
    f'| Validation: {valid_acc :.2f}%')
    train_acc_list.append(train_acc)
    valid_acc_list.append(valid_acc)
    
  elapsed = (time.time() - start_time)/60
  print(f'Time elapsed: {elapsed:.2f} min')

  if scheduler is not None:
    if scheduler_on == "valid_acc":
      scheduler.step(valid_acc_list[-1])
    elif scheduler_on == 'minibatch_loss':
      scheduler.step(minibatch_loss_list[-1])
    else:
      raise ValueError("Invalid `scheduler_on` choice")

total_elapsed = (time.time() - start_time)/60
print(f'Total Training Time: {total_elapsed:.2f} min')

# Compute Test Accuracy
test_acc = compute_accuracy(model, test_loader, device=DEVICE)

print(f"Test accuracy: {test_acc:0.3f}")
Epoch = 001/010| Batch 0000/0023| Loss: 163.3353
Epoch: 1/010 | Train: 88.77% | Validation: 88.06%
Time elapsed: 1.91 min
Epoch = 002/010| Batch 0000/0023| Loss: 72.6810
Epoch: 2/010 | Train: 89.51% | Validation: 88.89%
Time elapsed: 3.82 min
Epoch = 003/010| Batch 0000/0023| Loss: 63.1871
Epoch: 3/010 | Train: 90.26% | Validation: 89.31%
Time elapsed: 5.73 min
Epoch = 004/010| Batch 0000/0023| Loss: 62.3928
Epoch: 4/010 | Train: 91.15% | Validation: 89.86%
Time elapsed: 7.65 min
Epoch = 005/010| Batch 0000/0023| Loss: 47.7385
Epoch: 5/010 | Train: 92.31% | Validation: 90.56%
Time elapsed: 9.56 min
Epoch = 006/010| Batch 0000/0023| Loss: 62.1415
Epoch: 6/010 | Train: 92.20% | Validation: 90.35%
Time elapsed: 11.48 min
Epoch = 007/010| Batch 0000/0023| Loss: 46.1223
Epoch: 7/010 | Train: 92.71% | Validation: 91.60%
Time elapsed: 13.40 min
Epoch = 008/010| Batch 0000/0023| Loss: 44.8405
Epoch: 8/010 | Train: 92.74% | Validation: 90.90%
Time elapsed: 15.32 min
Epoch = 009/010| Batch 0000/0023| Loss: 37.8981
Epoch: 9/010 | Train: 93.18% | Validation: 91.46%
Time elapsed: 17.24 min
Epoch = 010/010| Batch 0000/0023| Loss: 29.6452
Epoch: 10/010 | Train: 93.77% | Validation: 92.22%
Time elapsed: 19.16 min
Total Training Time: 19.16 min
Test accuracy: 91.722

Comments

  • Training time was reduced by 11 minutes compared with training AlexNet from scratch in the previous blog!
  • Test accuracy within 10 epochs increased to 91.722% from 11%. So fine-tuning did help improve AlexNet's generalizability on the target dataset.

References

1. Code:

Sebastian Raschka's implementation on CIFAR-10 is a great start

2. Theory:

D2L's explanation is concise and simple for first-time readers.