The Fish and the Transferred (Alex)Net
Implementation of Transfer Learning on a custom dataset
Introduction
-
Most of us have noticed that learning one task makes it easier to learn a similar one. For instance, I played badminton throughout my teenage years, and when I later tried tennis I found it quite easy to pick up the nuances of the game. Similarly, in Deep Learning we can use Transfer Learning, i.e., take a model pretrained on a similar dataset or problem and reuse it for a custom dataset, which is usually small.
-
In the previous blog post, we trained AlexNet from scratch and did not observe particularly good results on the custom dataset.
-
Here, we instead use the pretrained AlexNet provided by PyTorch and observe how the training time and accuracy change.
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.models as models
import sklearn
from sklearn.model_selection import train_test_split
import numpy as np
import pandas as pd
import random
import os
import matplotlib.pyplot as plt
import PIL
from PIL import Image
import time
import seaborn as sns
import glob
from pathlib import Path
# Seed PyTorch and NumPy so that repeated runs draw the same random numbers.
torch.manual_seed(1)
np.random.seed(1)

data_path = Path.cwd()/'Fish_Dataset/Fish_Dataset'

# Path for all the files in a 'png' format.
image_path = list(data_path.glob('**/*.png'))


def _is_segmented(img):
    """Return True when the part of the path below ``data_path`` contains 'GT'.

    Only the path relative to the dataset root is inspected, so a working
    directory that happens to contain "GT" cannot misclassify every image.
    """
    return 'GT' in str(img.relative_to(data_path))


# Separate Segmented from Non-Segmented Images.
# The class name is read from the third-to-last path component.
non_segmented_images = [img for img in image_path if not _is_segmented(img)]
labels_non_segment = [img.parts[-3] for img in non_segmented_images]
segmented_images = [img for img in image_path if _is_segmented(img)]
lables_segment = [img.parts[-3] for img in segmented_images]

# Sort the unique class names: plain set iteration order depends on string
# hashing and can differ between interpreter runs, which would silently
# change which integer each class is mapped to.
classes = sorted(set(lables_segment))

# Convert String Labels to int
int_classes = {fish: i for i, fish in enumerate(classes)}
lables = [int_classes[lable] for lable in labels_non_segment]

# One row per non-segmented image: its path and its integer class label.
image_data = pd.DataFrame({'Path': non_segmented_images,
                           'labels': lables})
class FishDataset(Dataset):
    """Map-style dataset that loads one image from disk per requested index.

    Parameters
    ----------
    images :
        Collection of image file paths, indexed positionally via ``.iloc``
        (e.g. a pandas Series).
    labels :
        Collection of labels aligned with ``images``, also indexed via
        ``.iloc``. Its length defines the dataset length.
    transform : callable, optional
        Applied to the loaded PIL image before it is returned.
    """

    def __init__(self, images, labels, transform=None):
        self.images = images
        self.labels = labels
        self.transform = transform

    def __len__(self):
        """Return the number of samples (the number of labels)."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair stored at position ``idx``."""
        # Image.open is lazy and keeps the file open; decoding inside the
        # `with` block guarantees the handle is released. Converting to RGB
        # gives every sample three channels, which the 3-value Normalize
        # used in this file requires (grayscale / RGBA files would fail).
        with Image.open(self.images.iloc[idx]) as handle:
            img = handle.convert('RGB')
        if self.transform:
            img = self.transform(img)
        label = self.labels.iloc[idx]
        return img, label
# First hold out 20% of all samples as the test set ...
train, test, train_labels, test_labels = train_test_split(
    image_data.Path,
    image_data.labels,
    test_size=0.2,
    shuffle=True,
)
# ... then carve 20% of the remaining samples off as the validation set
# (i.e. 16% of the full data), leaving 64% for training.
train, val, train_labels, val_labels = train_test_split(
    train,
    train_labels,
    test_size=0.2,
    shuffle=True,
)
def get_loaders(train, train_labels, val, val_labels, test, test_labels,
                batch_size, num_workers, train_transform, test_transform):
    """Build the train, validation and test DataLoaders.

    Parameters
    ----------
    train, val, test :
        Image paths for each split, passed to ``FishDataset`` as ``images``.
    train_labels, val_labels, test_labels :
        Labels aligned with the corresponding image paths.
    batch_size : int
        Batch size used by all three loaders.
    num_workers : int
        Number of worker processes used by all three loaders.
    train_transform : callable
        Transform applied to training images.
    test_transform : callable
        Transform applied to validation and test images.

    Returns
    -------
    tuple
        ``(train_loader, val_loader, test_loader)``. Only the training
        loader shuffles its samples.
    """
    train_ds = FishDataset(images=train, labels=train_labels,
                           transform=train_transform)
    # Use the `test_transform` argument; the previous code ignored it and
    # silently depended on a module-level variable named `test_transforms`.
    val_ds = FishDataset(images=val, labels=val_labels,
                         transform=test_transform)
    test_ds = FishDataset(images=test, labels=test_labels,
                          transform=test_transform)

    train_loader = DataLoader(train_ds, batch_size=batch_size,
                              num_workers=num_workers, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=batch_size,
                            num_workers=num_workers, shuffle=False)
    test_loader = DataLoader(test_ds, batch_size=batch_size,
                             num_workers=num_workers, shuffle=False)
    return train_loader, val_loader, test_loader
def set_all_seeds(seed):
    """Seed every random number generator this file relies on.

    Stores ``seed`` (as a string) in the ``PL_GLOBAL_SEED`` environment
    variable, then seeds Python's ``random`` module, NumPy, PyTorch and all
    CUDA devices with the same value.
    """
    os.environ["PL_GLOBAL_SEED"] = str(seed)
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)
def compute_accuracy(model, data_loader, device):
    """Return the classification accuracy of ``model`` on ``data_loader``.

    Parameters
    ----------
    model : torch.nn.Module
        Network producing one row of class logits per input example.
    data_loader : iterable
        Yields ``(features, targets)`` batches.
    device : torch.device
        Device the batches are moved to before the forward pass.

    Returns
    -------
    torch.Tensor
        Zero-dimensional float tensor holding the accuracy in percent.

    Raises
    ------
    ValueError
        If ``data_loader`` yields no examples.
    """
    # The original wrote `model.eval` without parentheses, which is a no-op
    # and left dropout active during evaluation. Switch to eval mode for
    # real, then restore whatever mode the caller had set.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            correct_pred, num_examples = 0, 0
            for features, targets in data_loader:
                features = features.to(device)
                targets = targets.to(device)
                logits = model(features)
                # Index of the largest logit along dim 1 is the prediction.
                _, predicted_labels = torch.max(logits, 1)
                num_examples += targets.size(0)
                correct_pred += (predicted_labels == targets).sum()
    finally:
        model.train(was_training)

    if num_examples == 0:
        raise ValueError("data_loader yielded no examples")
    return correct_pred.float()/num_examples * 100
class UnNormalize(object):
    """Undo a per-channel ``(x - mean) / std`` normalization."""

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        """
        Parameters:
        ------------
        tensor (Tensor): Tensor whose first dimension is iterated
            channel by channel (e.g. an image of size (C, H, W)).
        Returns:
        ------------
        Tensor: The same tensor object, with each channel multiplied by
            its std and shifted by its mean. The input is modified in place.
        """
        per_channel = zip(tensor, self.mean, self.std)
        for channel, mu, sigma in per_channel:
            # In-place ops on the channel view write through to `tensor`.
            channel.mul_(sigma)
            channel.add_(mu)
        return tensor
##########################
### FISH DATASET
##########################

# Shared settings for both pipelines.
# NOTE(review): the mean/std values look like the per-channel statistics
# torchvision documents for its pretrained models - confirm.
IMAGE_SIZE = (64, 64)
CHANNEL_MEAN = (0.485, 0.456, 0.406)
CHANNEL_STD = (0.229, 0.224, 0.225)

# Training pipeline: resize, random brightness jitter, tensor conversion,
# per-channel normalization.
train_transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ColorJitter(brightness=0.5, contrast=0, saturation=0, hue=0),
    transforms.ToTensor(),
    transforms.Normalize(CHANNEL_MEAN, CHANNEL_STD),
])

# Evaluation pipeline: identical, but without the random jitter.
test_transforms = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(CHANNEL_MEAN, CHANNEL_STD),
])
# Number of output units of the replacement classification head.
NUM_CLASSES = 9

# Load AlexNet with pretrained weights.
alexnet = models.alexnet(pretrained=True)
alexnet  # bare expression: displays the architecture when run as a notebook cell

# Freeze every pretrained parameter ...
for param in alexnet.parameters():
    param.requires_grad = False

# ... swap the final classifier layer for a fresh NUM_CLASSES-way head ...
alexnet.classifier[6] = nn.Linear(in_features=4096, out_features=NUM_CLASSES, bias=True)

# ... and unfreeze the classifier layers that should be fine-tuned.
# `requires_grad` must be set on the *parameters*: assigning
# `module.requires_grad = True` (as the original did) only creates an
# unused attribute on the nn.Module and leaves its weights frozen.
for layer_idx in (1, 4, 6):
    for param in alexnet.classifier[layer_idx].parameters():
        param.requires_grad = True
##########################
### SETTINGS
##########################
RANDOM_SEED = 123
BATCH_SIZE = 256
NUM_EPOCHS = 10
WORKERS = 2
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

set_all_seeds(RANDOM_SEED)

train_loader, val_loader, test_loader = get_loaders(
    train, train_labels, val, val_labels, test, test_labels,
    BATCH_SIZE, WORKERS, train_transform, test_transforms)

model = alexnet.to(DEVICE)

optimizer = torch.optim.SGD(model.parameters(), momentum=0.9, lr=0.1)

# Print training progress every `logging_interval` minibatches.
logging_interval = 50
# Quantity the LR scheduler monitors: 'valid_acc' or 'minibatch_loss'.
scheduler_on = 'minibatch_loss'

# The plateau direction has to match the monitored quantity: accuracy should
# increase ('max') whereas a loss should decrease ('min'). The original
# hard-coded 'max' while stepping on the minibatch loss, so every decrease in
# loss was counted as "no improvement".
scheduler_mode = 'max' if scheduler_on == 'valid_acc' else 'min'
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                       factor=0.1,
                                                       mode=scheduler_mode,
                                                       verbose=True)
# Train for NUM_EPOCHS epochs, logging the minibatch loss during training and
# the train/validation accuracy after every epoch, then report test accuracy.
start_time = time.time()
minibatch_loss_list, train_acc_list, valid_acc_list = [], [], []

for epoch in range(NUM_EPOCHS):

    # Start Training
    model.train()
    for batch_idx, (features, target) in enumerate(train_loader):

        features = features.to(DEVICE)
        targets = target.to(DEVICE)

        # Forward and BackPropagation
        logits = model(features)
        loss = F.cross_entropy(logits, targets)
        optimizer.zero_grad()
        loss.backward()

        # Update Model Parameters
        optimizer.step()

        ## LOGGING
        minibatch_loss_list.append(loss.item())
        if not batch_idx % logging_interval:
            print(f"Epoch = {epoch+1:03d}/{NUM_EPOCHS:03d}"
                  f"| Batch {batch_idx:04d}/{len(train_loader):04d}"
                  f"| Loss: {loss:.4f}")

    ## Validation
    model.eval()
    with torch.no_grad():
        train_acc = compute_accuracy(model, train_loader, DEVICE)
        valid_acc = compute_accuracy(model, val_loader, DEVICE)
        print(f'Epoch: {epoch+1}/{NUM_EPOCHS:03d} '
              f'| Train: {train_acc :.2f}% '
              f'| Validation: {valid_acc :.2f}%')
        train_acc_list.append(train_acc)
        valid_acc_list.append(valid_acc)

    elapsed = (time.time() - start_time)/60
    print(f'Time elapsed: {elapsed:.2f} min')

    if scheduler is not None:
        # `elif` is required here: with two independent `if`s the `else`
        # branch also fired for "valid_acc", raising right after a valid step.
        if scheduler_on == "valid_acc":
            scheduler.step(valid_acc_list[-1])
        elif scheduler_on == 'minibatch_loss':
            scheduler.step(minibatch_loss_list[-1])
        else:
            raise ValueError("Invalid `scheduler_on` choice")

total_elapsed = (time.time() - start_time)/60
print(f'Total Training Time: {total_elapsed:.2f} min')

# Compute Test Accuracy
test_acc = compute_accuracy(model, test_loader, device=DEVICE)
print(f"Test accuracy: {test_acc:0.3f}")
References
1. Code:
Sebastian Raschka's implementation on CIFAR-10 is a great start
2. Theory:
D2L's explanation is concise and simple for first-time readers.